diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..50ade43fd8c4e76966cfd638f61c68e291213e9b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/demo_narrow.gif filter=lfs diff=lfs merge=lfs -text diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000000000000000000000000000000000000..ceaaa12d9f5c80e118b4a9b10a03f75a9ebe17ae --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,17 @@ + + + + +## Why are these changes needed? + + + +## Related issue number (if applicable) + + + +## Checks + +- [ ] I've run `format.sh` to lint the changes in this PR. +- [ ] I've included any doc changes needed. +- [ ] I've made sure the relevant tests are passing (if applicable). diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml new file mode 100644 index 0000000000000000000000000000000000000000..8f122caeb6023625437cd1d307025e4f343ee593 --- /dev/null +++ b/.github/workflows/python-package.yml @@ -0,0 +1,30 @@ +name: Python package + +on: [push, pull_request] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.10"] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -e '.[dev]' + - name: Run linter + run: | + pylint -d all -e E0602 ./fastchat/ + - name: Check formatting + run: | + black --check . diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..4ae098374064ad074c824a7aee8204dc0dd84262 --- /dev/null +++ b/.gitignore @@ -0,0 +1,36 @@ +# Python +__pycache__ +*.pyc +*.egg-info +dist +.venv + +# Log +*.log +*.log.* +*.json +!playground/deepspeed_config_s2.json +!playground/deepspeed_config_s3.json + +# Editor +.idea +*.swp + +# Other +.DS_Store +wandb +output +checkpoints_flant5_3b + +# Data +*.pkl +*.csv +tests/state_of_the_union.txt + +# Build +build + +# Image data +serve_images +val2014 +vqa_examples \ No newline at end of file diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000000000000000000000000000000000000..864033fae31be8b04dec1df4eb855cae621ed96c --- /dev/null +++ b/.pylintrc @@ -0,0 +1,449 @@ +# This Pylint rcfile contains a best-effort configuration to uphold the +# best-practices and style described in the Google Python style guide: +# https://google.github.io/styleguide/pyguide.html +# +# Its canonical open-source location is: +# https://google.github.io/styleguide/pylintrc + +[MASTER] + +# Files or directories to be skipped. They should be base names, not paths. +ignore=third_party,ray_patches,providers + +# Files or directories matching the regex patterns are skipped. The regex +# matches against base names, not paths. +ignore-patterns= + +# Pickle collected data for later comparisons. +persistent=no + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Use multiple processes to speed up Pylint. +jobs=4 + +# Allow loading of arbitrary C extensions. 
Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +#enable= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" +disable=abstract-method, + apply-builtin, + arguments-differ, + attribute-defined-outside-init, + backtick, + bad-option-value, + basestring-builtin, + buffer-builtin, + c-extension-no-member, + consider-using-enumerate, + cmp-builtin, + cmp-method, + coerce-builtin, + coerce-method, + delslice-method, + div-method, + duplicate-code, + eq-without-hash, + execfile-builtin, + file-builtin, + filter-builtin-not-iterating, + fixme, + getslice-method, + global-statement, + hex-method, + idiv-method, + implicit-str-concat-in-sequence, + import-error, + import-self, + import-star-module-level, + inconsistent-return-statements, + input-builtin, + intern-builtin, + invalid-str-codec, + locally-disabled, + logging-format-interpolation, # FIXME(sky): make pass. + logging-fstring-interpolation, # FIXME(sky): make pass. + long-builtin, + long-suffix, + map-builtin-not-iterating, + misplaced-comparison-constant, + missing-function-docstring, + metaclass-assignment, + next-method-called, + next-method-defined, + no-absolute-import, + no-else-break, + no-else-continue, + no-else-raise, + no-else-return, + no-init, # added + no-member, + no-name-in-module, + no-self-use, + nonzero-method, + oct-method, + old-division, + old-ne-operator, + old-octal-literal, + old-raise-syntax, + parameter-unpacking, + print-statement, + raising-string, + range-builtin-not-iterating, + raw_input-builtin, + rdiv-method, + reduce-builtin, + relative-import, + reload-builtin, + round-builtin, + setslice-method, + signature-differs, + standarderror-builtin, + suppressed-message, + sys-max-int, + too-few-public-methods, + too-many-ancestors, + too-many-arguments, + too-many-boolean-expressions, + too-many-branches, + too-many-instance-attributes, + too-many-locals, + too-many-nested-blocks, + too-many-public-methods, + too-many-return-statements, + too-many-statements, + trailing-newlines, + unichr-builtin, + unicode-builtin, + unnecessary-pass, + unpacking-in-except, + useless-else-on-loop, + useless-object-inheritance, + useless-suppression, + using-cmp-argument, + wrong-import-order, + xrange-builtin, + zip-builtin-not-iterating, + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. 
You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Put messages in a separate file for each module / package specified on the +# command line instead of printing them on stdout. Reports (if any) will be +# written in a file name "pylint_global.[txt|html]". This option is deprecated +# and it will be removed in Pylint 2.0. +files-output=no + +# Tells whether to display a full report or only the messages +reports=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[BASIC] + +# Good variable names which should always be accepted, separated by a comma +good-names=main,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl + +# Regular expression matching correct function names +function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ + +# Regular expression matching correct variable names +variable-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct constant names +const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct attribute names +attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ + +# Regular expression matching correct argument names +argument-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class attribute names +class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct inline iteration names +inlinevar-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class names +class-rgx=^_?[A-Z][a-zA-Z0-9]*$ + +# Regular expression matching correct module names +module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ + +# Regular expression matching correct method names +method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=10 + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. 
Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=100 + +# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt +# lines made too long by directives to pytype. + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=(?x)( + ^\s*(\#\ )??$| + ^\s*(from\s+\S+\s+)?import\s+.+$) + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=yes + +# List of optional constructs for which whitespace checking is disabled. `dict- +# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. +# `trailing-comma` allows a space between comma and closing bracket: (a, ). +# `empty-line` allows space-only lines. +no-space-check= + +# Maximum number of lines in a module +max-module-lines=99999 + +# String used as indentation unit. The internal Google style guide mandates 2 +# spaces. Google's externaly-published style guide says 4, consistent with +# PEP 8. Here we use 4 spaces. +indent-string=' ' + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=TODO + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=yes + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). +dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + +# List of qualified module names which can have objects that can redefine +# builtins. 
+redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging,absl.logging,tensorflow.io.logging + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[SPELLING] + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub, + TERMIOS, + Bastion, + rexec, + sets + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant, absl + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls, + class_ + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "Exception" +overgeneral-exceptions=StandardError, + Exception, + BaseException + +####### + +# https://github.com/edaniszewski/pylint-quotes#configuration +string-quote=single +triple-quote=double +docstring-quote=double diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. 
+ + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/README.md b/README.md index 59c3aa3bcd3ed51dbc987768546938bd0f695910..7554c846587a8b66da95870e29c6d2e1adaaa0a2 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,359 @@ --- -title: Bf16 Vs Fp8 -emoji: 🐠 -colorFrom: gray -colorTo: indigo +title: bf16_vs_fp8 +app_file: fastchat/serve/gradio_block_arena_anony.py sdk: gradio sdk_version: 4.41.0 -app_file: app.py -pinned: false --- +# FastChat +| [**Demo**](https://chat.lmsys.org/) | [**Discord**](https://discord.gg/HSWAKCrnFx) | [**X**](https://x.com/lmsysorg) | -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +FastChat is an open platform for training, serving, and evaluating large language model based chatbots. +- FastChat powers Chatbot Arena (https://chat.lmsys.org/), serving over 10 million chat requests for 70+ LLMs. +- Chatbot Arena has collected over 500K human votes from side-by-side LLM battles to compile an online [LLM Elo leaderboard](https://leaderboard.lmsys.org). + +FastChat's core features include: +- The training and evaluation code for state-of-the-art models (e.g., Vicuna, MT-Bench). +- A distributed multi-model serving system with web UI and OpenAI-compatible RESTful APIs. + +## News +- [2024/03] 🔥 We released Chatbot Arena technical [report](https://arxiv.org/abs/2403.04132). +- [2023/09] We released **LMSYS-Chat-1M**, a large-scale real-world LLM conversation dataset. Read the [report](https://arxiv.org/abs/2309.11998). +- [2023/08] We released **Vicuna v1.5** based on Llama 2 with 4K and 16K context lengths. Download [weights](#vicuna-weights). +- [2023/07] We released **Chatbot Arena Conversations**, a dataset containing 33k conversations with human preferences. Download it [here](https://huggingface.co/datasets/lmsys/chatbot_arena_conversations). + +
+More + +- [2023/08] We released **LongChat v1.5** based on Llama 2 with 32K context lengths. Download [weights](#longchat). +- [2023/06] We introduced **MT-bench**, a challenging multi-turn question set for evaluating chatbots. Check out the blog [post](https://lmsys.org/blog/2023-06-22-leaderboard/). +- [2023/06] We introduced **LongChat**, our long-context chatbots and evaluation tools. Check out the blog [post](https://lmsys.org/blog/2023-06-29-longchat/). +- [2023/05] We introduced **Chatbot Arena** for battles among LLMs. Check out the blog [post](https://lmsys.org/blog/2023-05-03-arena). +- [2023/03] We released **Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90% ChatGPT Quality**. Check out the blog [post](https://vicuna.lmsys.org). + +
+ + + +## Contents +- [Install](#install) +- [Model Weights](#model-weights) +- [Inference with Command Line Interface](#inference-with-command-line-interface) +- [Serving with Web GUI](#serving-with-web-gui) +- [API](#api) +- [Evaluation](#evaluation) +- [Fine-tuning](#fine-tuning) +- [Citation](#citation) + +## Install + +### Method 1: With pip + +```bash +pip3 install "fschat[model_worker,webui]" +``` + +### Method 2: From source + +1. Clone this repository and navigate to the FastChat folder +```bash +git clone https://github.com/lm-sys/FastChat.git +cd FastChat +``` + +If you are running on Mac: +```bash +brew install rust cmake +``` + +2. Install Package +```bash +pip3 install --upgrade pip # enable PEP 660 support +pip3 install -e ".[model_worker,webui]" +``` + +## Model Weights +### Vicuna Weights +[Vicuna](https://lmsys.org/blog/2023-03-30-vicuna/) is based on Llama 2 and should be used under Llama's [model license](https://github.com/facebookresearch/llama/blob/main/LICENSE). + +You can use the commands below to start chatting. It will automatically download the weights from Hugging Face repos. +Downloaded weights are stored in a `.cache` folder in the user's home folder (e.g., `~/.cache/huggingface/hub/`). + +See more command options and how to handle out-of-memory in the "Inference with Command Line Interface" section below. + +**NOTE: `transformers>=4.31` is required for 16K versions.** + +| Size | Chat Command | Hugging Face Repo | +| --- | --- | --- | +| 7B | `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5` | [lmsys/vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | +| 7B-16k | `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5-16k` | [lmsys/vicuna-7b-v1.5-16k](https://huggingface.co/lmsys/vicuna-7b-v1.5-16k) | +| 13B | `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-13b-v1.5` | [lmsys/vicuna-13b-v1.5](https://huggingface.co/lmsys/vicuna-13b-v1.5) | +| 13B-16k | `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-13b-v1.5-16k` | [lmsys/vicuna-13b-v1.5-16k](https://huggingface.co/lmsys/vicuna-13b-v1.5-16k) | +| 33B | `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-33b-v1.3` | [lmsys/vicuna-33b-v1.3](https://huggingface.co/lmsys/vicuna-33b-v1.3) | + +**Old weights**: see [docs/vicuna_weights_version.md](docs/vicuna_weights_version.md) for all versions of weights and their differences. + +### Other Models +Besides Vicuna, we also released two additional models: [LongChat](https://lmsys.org/blog/2023-06-29-longchat/) and FastChat-T5. +You can use the commands below to chat with them. They will automatically download the weights from Hugging Face repos. + +| Model | Chat Command | Hugging Face Repo | +| --- | --- | --- | +| LongChat-7B | `python3 -m fastchat.serve.cli --model-path lmsys/longchat-7b-32k-v1.5` | [lmsys/longchat-7b-32k](https://huggingface.co/lmsys/longchat-7b-32k-v1.5) | +| FastChat-T5-3B | `python3 -m fastchat.serve.cli --model-path lmsys/fastchat-t5-3b-v1.0` | [lmsys/fastchat-t5-3b-v1.0](https://huggingface.co/lmsys/fastchat-t5-3b-v1.0) | + +## Inference with Command Line Interface + + + +(Experimental Feature: You can specify `--style rich` to enable rich text output and better text streaming quality for some non-ASCII content. This may not work properly on certain terminals.) 
+
+#### Supported Models
+FastChat supports a wide range of models, including
+Llama 2, Vicuna, Alpaca, Baize, ChatGLM, Dolly, Falcon, FastChat-T5, GPT4All, Guanaco, MPT, OpenAssistant, OpenChat, RedPajama, StableLM, WizardLM, xDAN-AI and more.
+
+See a complete list of supported models and instructions to add a new model [here](docs/model_support.md).
+
+#### Single GPU
+The command below requires around 14GB of GPU memory for Vicuna-7B and 28GB of GPU memory for Vicuna-13B.
+See the ["Not Enough Memory" section](#not-enough-memory) below if you do not have enough memory.
+`--model-path` can be a local folder or a Hugging Face repo name.
+```
+python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5
+```
+
+#### Multiple GPUs
+You can use model parallelism to aggregate GPU memory from multiple GPUs on the same machine.
+```
+python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --num-gpus 2
+```
+
+Tips:
+Sometimes the "auto" device mapping strategy in huggingface/transformers does not perfectly balance the memory allocation across multiple GPUs.
+You can use `--max-gpu-memory` to specify the maximum memory per GPU for storing model weights.
+This allows it to allocate more memory for activations, so you can use longer context lengths or larger batch sizes. For example,
+
+```
+python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --num-gpus 2 --max-gpu-memory 8GiB
+```
+
+#### CPU Only
+This runs on the CPU only and does not require a GPU. It requires around 30GB of CPU memory for Vicuna-7B and around 60GB of CPU memory for Vicuna-13B.
+```
+python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --device cpu
+```
+
+Use Intel's AI accelerator instructions (AVX512_BF16/AMX) to accelerate CPU inference.
+```
+CPU_ISA=amx python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --device cpu
+```
+
+#### Metal Backend (Mac Computers with Apple Silicon or AMD GPUs)
+Use `--device mps` to enable GPU acceleration on Mac computers (requires torch >= 2.0).
+Use `--load-8bit` to turn on 8-bit compression.
+```
+python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --device mps --load-8bit
+```
+Vicuna-7B can run on a 32GB M1 MacBook at 1-2 words per second.
+
+#### Intel XPU (Intel Data Center and Arc A-Series GPUs)
+Install the [Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/xpu/latest/tutorials/installation.html). Set the OneAPI environment variables:
+```
+source /opt/intel/oneapi/setvars.sh
+```
+
+Use `--device xpu` to enable XPU/GPU acceleration.
+```
+python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --device xpu
+```
+Vicuna-7B can run on an Intel Arc A770 16GB.
+
+#### Ascend NPU
+Install the [Ascend PyTorch Adapter](https://github.com/Ascend/pytorch). Set the CANN environment variables:
+```
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+```
+
+Use `--device npu` to enable NPU acceleration.
+```
+python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --device npu
+```
+Vicuna-7B/13B can run on an Ascend NPU.
+
+#### Not Enough Memory
+If you do not have enough memory, you can enable 8-bit compression by adding `--load-8bit` to the commands above.
+This can reduce memory usage by around half with slightly degraded model quality.
+It is compatible with the CPU, GPU, and Metal backends.
+
+Vicuna-13B with 8-bit compression can run on a single GPU with 16 GB of VRAM, like an Nvidia RTX 3090, RTX 4080, T4, V100 (16GB), or an AMD RX 6800 XT.
+
+```
+python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --load-8bit
+```
+
+In addition to that, you can add `--cpu-offloading` to the commands above to offload weights that don't fit on your GPU onto the CPU memory.
+This requires 8-bit compression to be enabled and the bitsandbytes package to be installed, which is only available on Linux operating systems.
+
+#### More Platforms and Quantization
+- For AMD GPU users, please install ROCm and [the ROCm version of PyTorch](https://pytorch.org/get-started/locally/) before you install FastChat. See also this [post](https://github.com/lm-sys/FastChat/issues/104#issuecomment-1613791563).
+- FastChat supports ExLlama V2. See [docs/exllama_v2.md](/docs/exllama_v2.md).
+- FastChat supports GPTQ 4bit inference with [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa). See [docs/gptq.md](/docs/gptq.md).
+- FastChat supports AWQ 4bit inference with [mit-han-lab/llm-awq](https://github.com/mit-han-lab/llm-awq). See [docs/awq.md](/docs/awq.md).
+- [MLC LLM](https://mlc.ai/mlc-llm/), backed by the [TVM Unity](https://github.com/apache/tvm/tree/unity) compiler, deploys Vicuna natively on phones, consumer-class GPUs, and web browsers via Vulkan, Metal, CUDA, and WebGPU.
+
+#### Use models from ModelScope
+For Chinese users, you can use models from www.modelscope.cn by setting the following environment variable.
+```bash
+export FASTCHAT_USE_MODELSCOPE=True
+```
+
+## Serving with Web GUI
+
+To serve using the web UI, you need three main components: web servers that interface with users, model workers that host one or more models, and a controller to coordinate the web server and model workers. You can learn more about the architecture [here](docs/server_arch.md).
+
+Here are the commands to follow in your terminal:
+
+#### Launch the controller
+```bash
+python3 -m fastchat.serve.controller
+```
+
+This controller manages the distributed workers.
+
+#### Launch the model worker(s)
+```bash
+python3 -m fastchat.serve.model_worker --model-path lmsys/vicuna-7b-v1.5
+```
+Wait until the process finishes loading the model and you see "Uvicorn running on ...". The model worker will register itself with the controller.
+
+To ensure that your model worker is connected to your controller properly, send a test message using the following command:
+```bash
+python3 -m fastchat.serve.test_message --model-name vicuna-7b-v1.5
+```
+You will see a short output.
+
+#### Launch the Gradio web server
+```bash
+python3 -m fastchat.serve.gradio_web_server
+```
+
+This is the web interface that users will interact with.
+
+By following these steps, you will be able to serve your models using the web UI. You can open your browser and chat with a model now.
+If the models do not show up, try rebooting the Gradio web server.
+
+#### (Optional): Advanced Features, Scalability, Third Party UI
+- You can register multiple model workers with a single controller, which can be used to serve a single model with higher throughput or to serve multiple models at the same time. When doing so, please allocate different GPUs and ports to the different model workers.
+```
+# worker 0
+CUDA_VISIBLE_DEVICES=0 python3 -m fastchat.serve.model_worker --model-path lmsys/vicuna-7b-v1.5 --controller http://localhost:21001 --port 31000 --worker http://localhost:31000
+# worker 1
+CUDA_VISIBLE_DEVICES=1 python3 -m fastchat.serve.model_worker --model-path lmsys/fastchat-t5-3b-v1.0 --controller http://localhost:21001 --port 31001 --worker http://localhost:31001
+```
+- You can also launch a multi-tab Gradio server, which includes the Chatbot Arena tabs.
+```bash
+python3 -m fastchat.serve.gradio_web_server_multi
+```
+- The default model worker based on huggingface/transformers has great compatibility but can be slow. If you want high-throughput batched serving, you can try [vLLM integration](docs/vllm_integration.md).
+- If you want to host it with your own UI or a third-party UI, see [Third Party UI](docs/third_party_ui.md).
+
+## API
+### OpenAI-Compatible RESTful APIs & SDK
+FastChat provides OpenAI-compatible APIs for its supported models, so you can use FastChat as a local drop-in replacement for OpenAI APIs.
+The FastChat server is compatible with both the [openai-python](https://github.com/openai/openai-python) library and cURL commands.
+The REST API can also be run from the Google Colab free tier, as demonstrated in the [FastChat_API_GoogleColab.ipynb](https://github.com/lm-sys/FastChat/blob/main/playground/FastChat_API_GoogleColab.ipynb) notebook in our repository.
+See [docs/openai_api.md](docs/openai_api.md).
+
+### Hugging Face Generation APIs
+See [fastchat/serve/huggingface_api.py](fastchat/serve/huggingface_api.py).
+
+### LangChain Integration
+See [docs/langchain_integration](docs/langchain_integration.md).
+
+## Evaluation
+We use MT-bench, a set of challenging multi-turn open-ended questions, to evaluate models.
+To automate the evaluation process, we prompt strong LLMs like GPT-4 to act as judges and assess the quality of the models' responses.
+See instructions for running MT-bench at [fastchat/llm_judge](fastchat/llm_judge).
+
+MT-bench is now the recommended way to benchmark your models. If you are still looking for the old 80 questions used in the Vicuna blog post, please go to [vicuna-blog-eval](https://github.com/lm-sys/vicuna-blog-eval).
+
+## Fine-tuning
+### Data
+
+Vicuna is created by fine-tuning a Llama base model on approximately 125K user-shared conversations gathered from ShareGPT.com with public APIs. To ensure data quality, we convert the HTML back to markdown and filter out some inappropriate or low-quality samples. Additionally, we divide lengthy conversations into smaller segments that fit the model's maximum context length. For detailed instructions on cleaning the ShareGPT data, check out [here](docs/commands/data_cleaning.md).
+
+We will not release the ShareGPT dataset. If you would like to try the fine-tuning code, you can run it with some dummy conversations in [dummy_conversation.json](data/dummy_conversation.json). You can follow the same format and plug in your own data.
+
+### Code and Hyperparameters
+Our code is based on [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) with additional support for multi-turn conversations.
+We use hyperparameters similar to those of Stanford Alpaca.
+ +| Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay | +| --- | ---: | ---: | ---: | ---: | ---: | +| Vicuna-13B | 128 | 2e-5 | 3 | 2048 | 0 | + +### Fine-tuning Vicuna-7B with Local GPUs + +- Install dependency +```bash +pip3 install -e ".[train]" +``` + +- You can use the following command to train Vicuna-7B with 4 x A100 (40GB). Update `--model_name_or_path` with the actual path to Llama weights and `--data_path` with the actual path to data. +```bash +torchrun --nproc_per_node=4 --master_port=20001 fastchat/train/train_mem.py \ + --model_name_or_path meta-llama/Llama-2-7b-hf \ + --data_path data/dummy_conversation.json \ + --bf16 True \ + --output_dir output_vicuna \ + --num_train_epochs 3 \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 2 \ + --gradient_accumulation_steps 16 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 1200 \ + --save_total_limit 10 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --fsdp "full_shard auto_wrap" \ + --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --lazy_preprocess True +``` + +Tips: +- If you are using V100 which is not supported by FlashAttention, you can use the [memory-efficient attention](https://arxiv.org/abs/2112.05682) implemented in [xFormers](https://github.com/facebookresearch/xformers). Install xformers and replace `fastchat/train/train_mem.py` above with [fastchat/train/train_xformers.py](fastchat/train/train_xformers.py). +- If you meet out-of-memory due to "FSDP Warning: When using FSDP, it is efficient and recommended... ", see solutions [here](https://github.com/huggingface/transformers/issues/24724#issuecomment-1645189539). +- If you meet out-of-memory during model saving, see solutions [here](https://github.com/pytorch/pytorch/issues/98823). +- To turn on logging to popular experiment tracking tools such as Tensorboard, MLFlow or Weights & Biases, use the `report_to` argument, e.g. pass `--report_to wandb` to turn on logging to Weights & Biases. + +### Other models, platforms and LoRA support +More instructions to train other models (e.g., FastChat-T5) and use LoRA are in [docs/training.md](docs/training.md). + +### Fine-tuning on Any Cloud with SkyPilot +[SkyPilot](https://github.com/skypilot-org/skypilot) is a framework built by UC Berkeley for easily and cost effectively running ML workloads on any cloud (AWS, GCP, Azure, Lambda, etc.). +Find SkyPilot documentation [here](https://github.com/skypilot-org/skypilot/tree/master/llm/vicuna) on using managed spot instances to train Vicuna and save on your cloud costs. + +## Citation +The code (training, serving, and evaluation) in this repository is mostly developed for or derived from the paper below. +Please cite it if you find the repository helpful. + +``` +@misc{zheng2023judging, + title={Judging LLM-as-a-judge with MT-Bench and Chatbot Arena}, + author={Lianmin Zheng and Wei-Lin Chiang and Ying Sheng and Siyuan Zhuang and Zhanghao Wu and Yonghao Zhuang and Zi Lin and Zhuohan Li and Dacheng Li and Eric. P Xing and Hao Zhang and Joseph E. Gonzalez and Ion Stoica}, + year={2023}, + eprint={2306.05685}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` + +We are also planning to add more of our research to this repository. 
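As a quick illustration of the OpenAI-compatible API mentioned in the README's API section above, the snippet below is a minimal sketch rather than the project's official example. It assumes the `openai` Python package (>= 1.0) is installed, and that `fastchat.serve.openai_api_server` is listening on port 8000 with a `vicuna-7b-v1.5` worker registered, mirroring the defaults in the docker-compose file later in this diff.

```python
# Minimal sketch: chat with a FastChat-served model through its
# OpenAI-compatible endpoint using the openai-python (>= 1.0) client.
# Assumes the API server is running, e.g.:
#   python3 -m fastchat.serve.openai_api_server --host 0.0.0.0 --port 8000
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",  # FastChat's OpenAI-compatible endpoint
    api_key="EMPTY",                      # placeholder; no real OpenAI key is needed
)

response = client.chat.completions.create(
    model="vicuna-7b-v1.5",               # must match a registered model worker name
    messages=[{"role": "user", "content": "Tell me a fun fact about vicunas."}],
    temperature=0.7,
)
print(response.choices[0].message.content)
```

If the request fails, a plain `curl http://localhost:8000/v1/models` should list whatever models the controller currently knows about, which is a quick way to confirm that the controller, worker, and API server are wired together.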
diff --git a/assets/demo_narrow.gif b/assets/demo_narrow.gif new file mode 100644 index 0000000000000000000000000000000000000000..9f96f43db6c64c37ac45b33616e6d50f664f6070 --- /dev/null +++ b/assets/demo_narrow.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6e02d6a8fc6820f784105b3515b595730cc74542b4cf3f2a84a4361a0db17766 +size 12261733 diff --git a/assets/qa_browser.png b/assets/qa_browser.png new file mode 100644 index 0000000000000000000000000000000000000000..61ab4b0991d304a2a3c34a9191d4c98b2fae3e6e Binary files /dev/null and b/assets/qa_browser.png differ diff --git a/assets/screenshot_cli.png b/assets/screenshot_cli.png new file mode 100644 index 0000000000000000000000000000000000000000..7a7dd5d6b3456c4919a623bf0f1c00c539cb99f4 Binary files /dev/null and b/assets/screenshot_cli.png differ diff --git a/assets/screenshot_gui.png b/assets/screenshot_gui.png new file mode 100644 index 0000000000000000000000000000000000000000..ecb41d2f03f132ac882024cd4f61b082d0d8d6ac Binary files /dev/null and b/assets/screenshot_gui.png differ diff --git a/assets/server_arch.png b/assets/server_arch.png new file mode 100644 index 0000000000000000000000000000000000000000..06f8b2acf2bd64ff525f3065c393c19a5b3441df Binary files /dev/null and b/assets/server_arch.png differ diff --git a/assets/vicuna_logo.jpeg b/assets/vicuna_logo.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..e7883dc886b96d078883e01aefd16792133e204a Binary files /dev/null and b/assets/vicuna_logo.jpeg differ diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..4fc41918bebddec7324b00702cd58ce3fad5cb7b --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,7 @@ +FROM nvidia/cuda:12.2.0-runtime-ubuntu20.04 + +RUN apt-get update -y && apt-get install -y python3.9 python3.9-distutils curl +RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py +RUN python3.9 get-pip.py +RUN pip3 install fschat +RUN pip3 install fschat[model_worker,webui] \ No newline at end of file diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..113e0c7a36511518307a5c299defc1d6dbb7fc50 --- /dev/null +++ b/docker/docker-compose.yml @@ -0,0 +1,36 @@ +version: "3.9" + +services: + fastchat-controller: + build: + context: . + dockerfile: Dockerfile + image: fastchat:latest + ports: + - "21001:21001" + entrypoint: ["python3.9", "-m", "fastchat.serve.controller", "--host", "0.0.0.0", "--port", "21001"] + fastchat-model-worker: + build: + context: . + dockerfile: Dockerfile + volumes: + - huggingface:/root/.cache/huggingface + image: fastchat:latest + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + entrypoint: ["python3.9", "-m", "fastchat.serve.model_worker", "--model-names", "${FASTCHAT_WORKER_MODEL_NAMES:-vicuna-7b-v1.5}", "--model-path", "${FASTCHAT_WORKER_MODEL_PATH:-lmsys/vicuna-7b-v1.5}", "--worker-address", "http://fastchat-model-worker:21002", "--controller-address", "http://fastchat-controller:21001", "--host", "0.0.0.0", "--port", "21002"] + fastchat-api-server: + build: + context: . 
+ dockerfile: Dockerfile + image: fastchat:latest + ports: + - "8000:8000" + entrypoint: ["python3.9", "-m", "fastchat.serve.openai_api_server", "--controller-address", "http://fastchat-controller:21001", "--host", "0.0.0.0", "--port", "8000"] +volumes: + huggingface: diff --git a/docs/arena.md b/docs/arena.md new file mode 100644 index 0000000000000000000000000000000000000000..b0d87aa2fc0eadd2a9d81db9fe5e785a6a981e1a --- /dev/null +++ b/docs/arena.md @@ -0,0 +1,48 @@ +# Chatbot Arena +Chatbot Arena is an LLM benchmark platform featuring anonymous, randomized battles, available at https://chat.lmsys.org. +We invite the entire community to join this benchmarking effort by contributing your votes and models. + +## How to add a new model +If you want to see a specific model in the arena, you can follow the methods below. + +### Method 1: Hosted by 3rd party API providers or yourself +If you have a model hosted by a 3rd party API provider or yourself, please give us the access to an API endpoint. + - We prefer OpenAI-compatible APIs, so we can reuse our [code](https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/api_provider.py) for calling OpenAI models. + - If you have your own API protocol, please follow the [instructions](model_support.md) to add them. Contribute your code by sending a pull request. + +### Method 2: Hosted by LMSYS +1. Contribute the code to support this model in FastChat by submitting a pull request. See [instructions](model_support.md). +2. After the model is supported, we will try to schedule some compute resources to host the model in the arena. However, due to the limited resources we have, we may not be able to serve every model. We will select the models based on popularity, quality, diversity, and other factors. + + +## How to launch vision arena + +1. Run `python3 -m fastchat.serve.controller` to start the controller and begin registering local model workers and API-provided workers. +2. Run `python3 -m fastchat.serve.sglang_worker --model-path --tokenizer-path ` to run local vision-language models. Currently supported models include the LLaVA and Yi-VL series. +3. If you are using a 3rd party model with an API provider (e.g. GPT-4-V, Gemini 1.5), please follow the instructions [model_support.md](model_support.md) to add a json file `api_endpoints.json`. +4. Run the gradio server with the `--vision-arena` flag on. +5. To run and store images into a remote directory, add the flag: `--use-remote-storage` +6. To run and allow samples of random questions, add `--random_questions metadata_sampled.json`. Check sections below for how to generate this. + +Example command: +``` +python3 -m fastchat.serve.gradio_web_server_multi --share --register-api-endpoint-file api_endpoints.json --vision-arena --use-remote-storage --random-questions metadata_sampled.json +``` + +### NSFW and CSAM Detection +1. Adding NSFW Endpoint and API key: Please add the following environment variables to run the NSFW moderation filter for images: + - `AZURE_IMG_MODERATION_ENDPOINT`: This is the endpoint that the NSFW moderator is hosted (e.g. https://{endpoint}/contentmoderator/moderate/v1.0/ProcessImage/Evaluate). Change the `endpoint` to your own. + - `AZURE_IMG_MODERATION_API_KEY`: Your API key to run this endpoint. +2. Adding CSAM API key: + - `PHOTODNA_API_KEY`: The API key that runs the CSAM detector endpoint. 
+ +Example in `~/.bashrc`: +``` +export AZURE_IMG_MODERATION_ENDPOINT=https:///contentmoderator/moderate/v1.0/ProcessImage/Evaluate +export AZURE_IMG_MODERATION_API_KEY= +export PHOTODNA_API_KEY= +``` + +### Adding Random Samples for VQA +We provide random samples of example images for users to interact with coming from various datasets including DocVQA, RealWorldQA, ChartQA and VizWiz-VQA. +1. Download the images and generate random questions file by running `python fastchat/serve/vision/create_vqa_examples_dir.py` \ No newline at end of file diff --git a/docs/awq.md b/docs/awq.md new file mode 100644 index 0000000000000000000000000000000000000000..ac38ef3dd513c4346a1b8fc05a97ceda749030b6 --- /dev/null +++ b/docs/awq.md @@ -0,0 +1,71 @@ +# AWQ 4bit Inference + +We integrated [AWQ](https://github.com/mit-han-lab/llm-awq) into FastChat to provide **efficient and accurate** 4bit LLM inference. + +## Install AWQ + +Setup environment (please refer to [this link](https://github.com/mit-han-lab/llm-awq#install) for more details): +```bash +conda create -n fastchat-awq python=3.10 -y +conda activate fastchat-awq +# cd /path/to/FastChat +pip install --upgrade pip # enable PEP 660 support +pip install -e . # install fastchat + +git clone https://github.com/mit-han-lab/llm-awq repositories/llm-awq +cd repositories/llm-awq +pip install -e . # install awq package + +cd awq/kernels +python setup.py install # install awq CUDA kernels +``` + +## Chat with the CLI + +```bash +# Download quantized model from huggingface +# Make sure you have git-lfs installed (https://git-lfs.com) +git lfs install +git clone https://huggingface.co/mit-han-lab/vicuna-7b-v1.3-4bit-g128-awq + +# You can specify which quantized model to use by setting --awq-ckpt +python3 -m fastchat.serve.cli \ + --model-path models/vicuna-7b-v1.3-4bit-g128-awq \ + --awq-wbits 4 \ + --awq-groupsize 128 +``` + +## Benchmark + +* Through **4-bit weight quantization**, AWQ helps to run larger language models within the device memory restriction and prominently accelerates token generation. All benchmarks are done with group_size 128. 
+ +* Benchmark on NVIDIA RTX A6000: + + | Model | Bits | Max Memory (MiB) | Speed (ms/token) | AWQ Speedup | + | --------------- | ---- | ---------------- | ---------------- | ----------- | + | vicuna-7b | 16 | 13543 | 26.06 | / | + | vicuna-7b | 4 | 5547 | 12.43 | 2.1x | + | llama2-7b-chat | 16 | 13543 | 27.14 | / | + | llama2-7b-chat | 4 | 5547 | 12.44 | 2.2x | + | vicuna-13b | 16 | 25647 | 44.91 | / | + | vicuna-13b | 4 | 9355 | 17.30 | 2.6x | + | llama2-13b-chat | 16 | 25647 | 47.28 | / | + | llama2-13b-chat | 4 | 9355 | 20.28 | 2.3x | + +* NVIDIA RTX 4090: + + | Model | AWQ 4bit Speed (ms/token) | FP16 Speed (ms/token) | AWQ Speedup | + | --------------- | ------------------------- | --------------------- | ----------- | + | vicuna-7b | 8.61 | 19.09 | 2.2x | + | llama2-7b-chat | 8.66 | 19.97 | 2.3x | + | vicuna-13b | 12.17 | OOM | / | + | llama2-13b-chat | 13.54 | OOM | / | + +* NVIDIA Jetson Orin: + + | Model | AWQ 4bit Speed (ms/token) | FP16 Speed (ms/token) | AWQ Speedup | + | --------------- | ------------------------- | --------------------- | ----------- | + | vicuna-7b | 65.34 | 93.12 | 1.4x | + | llama2-7b-chat | 75.11 | 104.71 | 1.4x | + | vicuna-13b | 115.40 | OOM | / | + | llama2-13b-chat | 136.81 | OOM | / | diff --git a/docs/commands/conv_release.md b/docs/commands/conv_release.md new file mode 100644 index 0000000000000000000000000000000000000000..64cfadbc1dd209f225c67ab228e641429957a2b1 --- /dev/null +++ b/docs/commands/conv_release.md @@ -0,0 +1,38 @@ +## Chatbot Arena Conversations + +1. Gather battles +``` +python3 clean_battle_data.py --max-num 10 --mode conv_release +``` + +2. Tag OpenAI moderation +``` +python3 tag_openai_moderation.py --in clean_battle_conv_20230814.json +``` + +3. Clean PII + +4. Filter additional blocked words + +``` +python3 filter_bad_conv.py --in clean_battle_conv_20230630_tagged_v1_pii.json +``` + +5. Add additional toxicity tag + + +## All Conversations + +1. Gather chats +``` +python3 clean_chat_data.py +``` + +2. Sample +``` +python3 conv_release_scripts/sample.py +``` + + +## Prompt distribution + diff --git a/docs/commands/data_cleaning.md b/docs/commands/data_cleaning.md new file mode 100644 index 0000000000000000000000000000000000000000..410ce8a828c38b8ecca98aa40469d9e8b93b3580 --- /dev/null +++ b/docs/commands/data_cleaning.md @@ -0,0 +1,19 @@ +## Data cleaning + +## Requirements +``` +pip3 install bs4 markdownify +pip3 install polyglot pyicu pycld2 +``` + +## Steps +``` +# Convert html to markdown +python3 -m fastchat.data.clean_sharegpt --in sharegpt_html.json --out sharegpt_clean.json + +# Keep or remove specific languages +python3 -m fastchat.data.optional_clean --in sharegpt_clean.json --out sharegpt_clean_lang.json --skip-lang SOME_LANGUAGE_CODE + +# Split long conversations +python3 -m fastchat.data.split_long_conversation --in sharegpt_clean_lang.json --out sharegpt_clean_lang_split.json --model-name /home/ubuntu/model_weights/llama-7b/ +``` diff --git a/docs/commands/leaderboard.md b/docs/commands/leaderboard.md new file mode 100644 index 0000000000000000000000000000000000000000..04477a0959c97d7d11160400ee3d4cc3c1129681 --- /dev/null +++ b/docs/commands/leaderboard.md @@ -0,0 +1,37 @@ +### Get logs +``` +gsutil -m rsync -r gs://fastchat_logs ~/fastchat_logs/ +``` + +### Clean battle data +``` +cd ~/FastChat/fastchat/serve/monitor +python3 clean_battle_data.py +``` + +### Run Elo analysis +``` +python3 elo_analysis.py --clean-battle-file clean_battle_20230905.json +``` + +### Copy files to HF space +1. 
update plots +``` +scp atlas:/data/lmzheng/FastChat/fastchat/serve/monitor/elo_results_20230905.pkl . +``` + +2. update table +``` +wget https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard/raw/main/leaderboard_table_20230905.csv +``` + +### Update files on webserver +``` +DATE=20231002 + +rm -rf elo_results.pkl leaderboard_table.csv +wget https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard/resolve/main/elo_results_$DATE.pkl +wget https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard/resolve/main/leaderboard_table_$DATE.csv +ln -s leaderboard_table_$DATE.csv leaderboard_table.csv +ln -s elo_results_$DATE.pkl elo_results.pkl +``` diff --git a/docs/commands/local_cluster.md b/docs/commands/local_cluster.md new file mode 100644 index 0000000000000000000000000000000000000000..bc1588c0ffa89758bd94dcd345838b8f15d88655 --- /dev/null +++ b/docs/commands/local_cluster.md @@ -0,0 +1,38 @@ +### Local GPU cluster +node-01 +``` +python3 -m fastchat.serve.controller --host 0.0.0.0 --port 10002 + +CUDA_VISIBLE_DEVICES=0 python3 -m fastchat.serve.vllm_worker --model-path lmsys/vicuna-13b-v1.5 --model-name vicuna-13b --controller http://node-01:10002 --host 0.0.0.0 --port 31000 --worker-address http://$(hostname):31000 +CUDA_VISIBLE_DEVICES=1 python3 -m fastchat.serve.vllm_worker --model-path lmsys/vicuna-13b-v1.5 --model-name vicuna-13b --controller http://node-01:10002 --host 0.0.0.0 --port 31001 --worker-address http://$(hostname):31001 + +CUDA_VISIBLE_DEVICES=2,3 ray start --head +python3 -m fastchat.serve.vllm_worker --model-path lmsys/vicuna-33b-v1.3 --model-name vicuna-33b --controller http://node-01:10002 --host 0.0.0.0 --port 31002 --worker-address http://$(hostname):31002 --num-gpus 2 +``` + +node-02 +``` +CUDA_VISIBLE_DEVICES=0 python3 -m fastchat.serve.vllm_worker --model-path meta-llama/Llama-2-13b-chat-hf --model-name llama-2-13b-chat --controller http://node-01:10002 --host 0.0.0.0 --port 31000 --worker-address http://$(hostname):31000 --tokenizer meta-llama/Llama-2-7b-chat-hf +CUDA_VISIBLE_DEVICES=1 python3 -m fastchat.serve.vllm_worker --model-path meta-llama/Llama-2-13b-chat-hf --model-name llama-2-13b-chat --controller http://node-01:10002 --host 0.0.0.0 --port 31001 --worker-address http://$(hostname):31001 --tokenizer meta-llama/Llama-2-7b-chat-hf +CUDA_VISIBLE_DEVICES=2 python3 -m fastchat.serve.vllm_worker --model-path meta-llama/Llama-2-7b-chat-hf --model-name llama-2-7b-chat --controller http://node-01:10002 --host 0.0.0.0 --port 31002 --worker-address http://$(hostname):31002 --tokenizer meta-llama/Llama-2-7b-chat-hf +CUDA_VISIBLE_DEVICES=3 python3 -m fastchat.serve.vllm_worker --model-path WizardLM/WizardLM-13B-V1.1 --model-name wizardlm-13b --controller http://node-01:10002 --host 0.0.0.0 --port 31003 --worker-address http://$(hostname):31003 +``` + +node-03 +``` +python3 -m fastchat.serve.vllm_worker --model-path mosaicml/mpt-30b-chat --controller http://node-01:10002 --host 0.0.0.0 --port 31000 --worker-address http://$(hostname):31000 --num-gpus 2 +python3 -m fastchat.serve.vllm_worker --model-path timdettmers/guanaco-33b-merged --model-name guanaco-33b --controller http://node-01:10002 --host 0.0.0.0 --port 31002 --worker-address http://$(hostname):31002 --num-gpus 2 --tokenizer hf-internal-testing/llama-tokenizer +``` + +node-04 +``` +CUDA_VISIBLE_DEVICES=0 python3 -m fastchat.serve.multi_model_worker --model-path ~/model_weights/RWKV-4-Raven-14B-v12-Eng98%25-Other2%25-20230523-ctx8192.pth --model-name RWKV-4-Raven-14B --model-path 
lmsys/fastchat-t5-3b-v1.0 --model-name fastchat-t5-3b --controller http://node-01:10002 --host 0.0.0.0 --port 31000 --worker http://$(hostname):31000 --limit 4 +CUDA_VISIBLE_DEVICES=1 python3 -m fastchat.serve.multi_model_worker --model-path OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5 --model-name oasst-pythia-12b --model-path mosaicml/mpt-7b-chat --model-name mpt-7b-chat --controller http://node-01:10002 --host 0.0.0.0 --port 31001 --worker http://$(hostname):31001 --limit 4 +CUDA_VISIBLE_DEVICES=2 python3 -m fastchat.serve.multi_model_worker --model-path lmsys/vicuna-7b-v1.5 --model-name vicuna-7b --model-path THUDM/chatglm-6b --model-name chatglm-6b --controller http://node-01:10002 --host 0.0.0.0 --port 31002 --worker http://$(hostname):31002 --limit 4 +CUDA_VISIBLE_DEVICES=3 python3 -m fastchat.serve.vllm_worker --model-path ~/model_weights/alpaca-13b --controller http://node-01:10002 --host 0.0.0.0 --port 31003 --worker-address http://$(hostname):31003 +``` + +test +``` +python3 -m fastchat.serve.test_message --model vicuna-13b --controller http://localhost:10002 +``` diff --git a/docs/commands/pypi.md b/docs/commands/pypi.md new file mode 100644 index 0000000000000000000000000000000000000000..5b53dae6b9b22d883f03e707771482947d56ee02 --- /dev/null +++ b/docs/commands/pypi.md @@ -0,0 +1,11 @@ +### Requirement +``` +python3 -m pip install twine +python3 -m pip install --upgrade pip +pip3 install build +``` + +### Upload +``` +bash scripts/upload_pypi.sh +``` diff --git a/docs/commands/webserver.md b/docs/commands/webserver.md new file mode 100644 index 0000000000000000000000000000000000000000..df96cf8d2985026bf41779f76f1376bfe92b3040 --- /dev/null +++ b/docs/commands/webserver.md @@ -0,0 +1,94 @@ +### Install +``` +sudo apt update +sudo apt install tmux htop + +wget https://repo.anaconda.com/archive/Anaconda3-2022.10-Linux-x86_64.sh +bash Anaconda3-2022.10-Linux-x86_64.sh + +conda create -n fastchat python=3.9 +conda activate fastchat + +git clone https://github.com/lm-sys/FastChat.git +cd FastChat +pip3 install -e . +``` + + +### Launch servers +``` +cd fastchat_logs/controller +python3 -m fastchat.serve.controller --host 0.0.0.0 --port 21001 +python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name https:// +python3 -m fastchat.serve.test_message --model vicuna-13b --controller http://localhost:21001 + +cd fastchat_logs/server0 + +python3 -m fastchat.serve.huggingface_api_worker --model-info-file ~/elo_results/register_hf_api_models.json + +export OPENAI_API_KEY= +export ANTHROPIC_API_KEY= +export GCP_PROJECT_ID= + +python3 -m fastchat.serve.gradio_web_server_multi --controller http://localhost:21001 --concurrency 50 --add-chatgpt --add-claude --add-palm --elo ~/elo_results/elo_results.pkl --leaderboard-table-file ~/elo_results/leaderboard_table.csv --register ~/elo_results/register_oai_models.json --show-terms + +python3 backup_logs.py +``` + + +### Check the launch time +``` +for i in $(seq 0 11); do cat fastchat_logs/server$i/gradio_web_server.log | grep "Running on local URL" | tail -n 1; done +``` + + +### Increase the limit of max open files +One process (do not need reboot) +``` +sudo prlimit --nofile=1048576:1048576 --pid=$id + +for id in $(ps -ef | grep gradio_web_server | awk '{print $2}'); do echo $id; prlimit --nofile=1048576:1048576 --pid=$id; done +``` + +System (need reboot): Add the lines below to `/etc/security/limits.conf` +``` +* hard nofile 65535 +* soft nofile 65535 +``` + + +### Gradio edit (3.35.2) +1. 
gtag and canvas
+```
+vim /home/vicuna/anaconda3/envs/fastchat/lib/python3.9/site-packages/gradio/templates/frontend/index.html
+```
+
+```
+
+
+
+```
+
+2. deprecation warnings
+```
+vim /home/vicuna/anaconda3/envs/fastchat/lib/python3.9/site-packages/gradio/deprecation.py
+```
+
+```
+def check_deprecated_parameters(
+```
+
+3. Loading
+```
+vim /home/vicuna/anaconda3/envs/fastchat/lib/python3.9/site-packages/gradio/templates/frontend/assets/index-188ef5e8.js
+```
+
+```
+%s/"Loading..."/"Loading...(Please refresh if it takes more than 30 seconds)"/g
+```
diff --git a/docs/dashinfer_integration.md b/docs/dashinfer_integration.md
new file mode 100644
index 0000000000000000000000000000000000000000..9b2aff22e84a43e583bc8e8e23d97df690eacb28
--- /dev/null
+++ b/docs/dashinfer_integration.md
@@ -0,0 +1,23 @@
+# dash-infer Integration
+[DashInfer](https://github.com/modelscope/dash-infer) is a high-performance inference engine optimized for CPU environments. It accelerates a variety of models, including Llama, Qwen, and ChatGLM, which makes it a versatile, performant worker for FastChat. DashInfer delivers significant performance gains on both Intel x64 and ARMv9 processors, so it is a good fit for deploying large language models in resource-constrained environments or in scenarios where CPU inference is preferred over GPU acceleration.
+
+## Instructions
+1. Install dash-infer.
+   ```
+   pip install dashinfer
+   ```
+
+2. When you launch a model worker, replace the normal worker (`fastchat.serve.model_worker`) with the dash-infer worker (`fastchat.serve.dashinfer_worker`). All other commands such as controller, gradio web server, and OpenAI API server are kept the same.
+   ```
+   python3 -m fastchat.serve.dashinfer_worker --model-path qwen/Qwen-7B-Chat --revision=master /path/to/dashinfer-model-generation-config.json
+   ```
+   Here is an example:
+   ```
+   python3 -m fastchat.serve.dashinfer_worker --model-path qwen/Qwen-7B-Chat --revision=master dash-infer/examples/python/model_config/config_qwen_v10_7b.json
+   ```
+
+   If you use an already downloaded model, replace `--model-path` with a local path and choose a conversation template via the `--conv-template` option:
+   ```
+   python3 -m fastchat.serve.dashinfer_worker --model-path ~/.cache/modelscope/hub/qwen/Qwen-7B-Chat --conv-template qwen-7b-chat /path/to/dashinfer-model-generation-config.json
+   ```
+   All available conversation templates are listed in [fastchat/conversation.py](../fastchat/conversation.py).
diff --git a/docs/dataset_release.md b/docs/dataset_release.md
new file mode 100644
index 0000000000000000000000000000000000000000..add2c890992e182215d23c9d75d519eefa1fc8c6
--- /dev/null
+++ b/docs/dataset_release.md
@@ -0,0 +1,6 @@
+## Datasets
+We release the following datasets based on our projects and websites.
+
+- [LMSYS-Chat-1M: A Large-Scale Real-World LLM Conversation Dataset](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
+- [Chatbot Arena Conversation Dataset](https://huggingface.co/datasets/lmsys/chatbot_arena_conversations)
+- [MT-bench Human Annotation Dataset](https://huggingface.co/datasets/lmsys/mt_bench_human_judgments)
diff --git a/docs/exllama_v2.md b/docs/exllama_v2.md
new file mode 100644
index 0000000000000000000000000000000000000000..c9fc72bfbbe24795e378113214b1a57c7de2b4c0
--- /dev/null
+++ b/docs/exllama_v2.md
@@ -0,0 +1,63 @@
+# ExllamaV2 GPTQ Inference Framework
+
+We integrated the [ExllamaV2](https://github.com/turboderp/exllamav2) customized kernels into FastChat to provide **faster** GPTQ inference.
+
+**Note: Exllama does not yet support the embedding REST API.**
+
+## Install ExllamaV2
+
+Setup environment (please refer to [this link](https://github.com/turboderp/exllamav2#how-to) for more details):
+
+```bash
+git clone https://github.com/turboderp/exllamav2
+cd exllamav2
+pip install -e .
+```
+
+Chat with the CLI:
+```bash
+python3 -m fastchat.serve.cli \
+    --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \
+    --enable-exllama
+```
+
+Start model worker:
+```bash
+# Download quantized model from huggingface
+# Make sure you have git-lfs installed (https://git-lfs.com)
+git lfs install
+git clone https://huggingface.co/TheBloke/vicuna-7B-1.1-GPTQ-4bit-128g models/vicuna-7B-1.1-GPTQ-4bit-128g
+
+# Load model with default configuration (max sequence length 4096, no GPU split setting).
+python3 -m fastchat.serve.model_worker \
+    --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \
+    --enable-exllama
+
+# Load model with max sequence length 2048, allocate 18 GB to CUDA:0 and 24 GB to CUDA:1.
+python3 -m fastchat.serve.model_worker \
+    --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \
+    --enable-exllama \
+    --exllama-max-seq-len 2048 \
+    --exllama-gpu-split 18,24
+```
+
+`--exllama-cache-8bit` can be used to enable 8-bit caching with exllama and save some VRAM.
+
+## Performance
+
+Reference: https://github.com/turboderp/exllamav2#performance
+
+
+| Model | Mode | Size | grpsz | act | V1: 3090Ti | V1: 4090 | V2: 3090Ti | V2: 4090 |
+|------------|--------------|-------|-------|-----|------------|----------|------------|-------------|
+| Llama | GPTQ | 7B | 128 | no | 143 t/s | 173 t/s | 175 t/s | **195** t/s |
+| Llama | GPTQ | 13B | 128 | no | 84 t/s | 102 t/s | 105 t/s | **110** t/s |
+| Llama | GPTQ | 33B | 128 | yes | 37 t/s | 45 t/s | 45 t/s | **48** t/s |
+| OpenLlama | GPTQ | 3B | 128 | yes | 194 t/s | 226 t/s | 295 t/s | **321** t/s |
+| CodeLlama | EXL2 4.0 bpw | 34B | - | - | - | - | 42 t/s | **48** t/s |
+| Llama2 | EXL2 3.0 bpw | 7B | - | - | - | - | 195 t/s | **224** t/s |
+| Llama2 | EXL2 4.0 bpw | 7B | - | - | - | - | 164 t/s | **197** t/s |
+| Llama2 | EXL2 5.0 bpw | 7B | - | - | - | - | 144 t/s | **160** t/s |
+| Llama2 | EXL2 2.5 bpw | 70B | - | - | - | - | 30 t/s | **35** t/s |
+| TinyLlama | EXL2 3.0 bpw | 1.1B | - | - | - | - | 536 t/s | **635** t/s |
+| TinyLlama | EXL2 4.0 bpw | 1.1B | - | - | - | - | 509 t/s | **590** t/s |
diff --git a/docs/gptq.md b/docs/gptq.md
new file mode 100644
index 0000000000000000000000000000000000000000..4078d1e0dc7498885f55448edaac41258e68980c
--- /dev/null
+++ b/docs/gptq.md
@@ -0,0 +1,59 @@
+# GPTQ 4bit Inference
+
+FastChat supports GPTQ 4bit inference with [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa).
+
+1. Windows users: use the `old-cuda` branch.
+2. Linux users: the `fastest-inference-4bit` branch is recommended.
+ +## Install + +Setup environment: +```bash +# cd /path/to/FastChat +git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa.git repositories/GPTQ-for-LLaMa +cd repositories/GPTQ-for-LLaMa +# Window's user should use the `old-cuda` branch +git switch fastest-inference-4bit +# Install `quant-cuda` package in FastChat's virtualenv +python3 setup_cuda.py install +pip3 install texttable +``` + +Chat with the CLI: +```bash +python3 -m fastchat.serve.cli \ + --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \ + --gptq-wbits 4 \ + --gptq-groupsize 128 +``` + +Start model worker: +```bash +# Download quantized model from huggingface +# Make sure you have git-lfs installed (https://git-lfs.com) +git lfs install +git clone https://huggingface.co/TheBloke/vicuna-7B-1.1-GPTQ-4bit-128g models/vicuna-7B-1.1-GPTQ-4bit-128g + +python3 -m fastchat.serve.model_worker \ + --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \ + --gptq-wbits 4 \ + --gptq-groupsize 128 + +# You can specify which quantized model to use +python3 -m fastchat.serve.model_worker \ + --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \ + --gptq-ckpt models/vicuna-7B-1.1-GPTQ-4bit-128g/vicuna-7B-1.1-GPTQ-4bit-128g.safetensors \ + --gptq-wbits 4 \ + --gptq-groupsize 128 \ + --gptq-act-order +``` + +## Benchmark + +| LLaMA-13B | branch | Bits | group-size | memory(MiB) | PPL(c4) | Median(s/token) | act-order | speed up | +| --------- | ---------------------- | ---- | ---------- | ----------- | ------- | --------------- | --------- | -------- | +| FP16 | fastest-inference-4bit | 16 | - | 26634 | 6.96 | 0.0383 | - | 1x | +| GPTQ | triton | 4 | 128 | 8590 | 6.97 | 0.0551 | - | 0.69x | +| GPTQ | fastest-inference-4bit | 4 | 128 | 8699 | 6.97 | 0.0429 | true | 0.89x | +| GPTQ | fastest-inference-4bit | 4 | 128 | 8699 | 7.03 | 0.0287 | false | 1.33x | +| GPTQ | fastest-inference-4bit | 4 | -1 | 8448 | 7.12 | 0.0284 | false | 1.44x | diff --git a/docs/langchain_integration.md b/docs/langchain_integration.md new file mode 100644 index 0000000000000000000000000000000000000000..50174a85e654768d7e98825586b4a79b457bf71a --- /dev/null +++ b/docs/langchain_integration.md @@ -0,0 +1,90 @@ +# Local LangChain with FastChat + +[LangChain](https://python.langchain.com/en/latest/index.html) is a library that facilitates the development of applications by leveraging large language models (LLMs) and enabling their composition with other sources of computation or knowledge. +FastChat's OpenAI-compatible [API server](openai_api.md) enables using LangChain with open models seamlessly. + +## Launch RESTful API Server + +Here are the steps to launch a local OpenAI API server for LangChain. + +First, launch the controller + +```bash +python3 -m fastchat.serve.controller +``` + +LangChain uses OpenAI model names by default, so we need to assign some faux OpenAI model names to our local model. +Here, we use Vicuna as an example and use it for three endpoints: chat completion, completion, and embedding. +`--model-path` can be a local folder or a Hugging Face repo name. +See a full list of supported models [here](../README.md#supported-models). + +```bash +python3 -m fastchat.serve.model_worker --model-names "gpt-3.5-turbo,text-davinci-003,text-embedding-ada-002" --model-path lmsys/vicuna-7b-v1.5 +``` + +Finally, launch the RESTful API server + +```bash +python3 -m fastchat.serve.openai_api_server --host localhost --port 8000 +``` + +## Set OpenAI Environment + +You can set your environment with the following commands. 
+ +Set OpenAI base url + +```bash +export OPENAI_API_BASE=http://localhost:8000/v1 +``` + +Set OpenAI API key + +```bash +export OPENAI_API_KEY=EMPTY +``` + +If you meet the following OOM error while creating embeddings, please set a smaller batch size by using environment variables. + +~~~bash +openai.error.APIError: Invalid response object from API: '{"object":"error","message":"**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**\\n\\n(CUDA out of memory. Tried to allocate xxx MiB (GPU 0; xxx GiB total capacity; xxx GiB already allocated; xxx MiB free; xxx GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF)","code":50002}' (HTTP response code was 400) +~~~ + +You can try `export FASTCHAT_WORKER_API_EMBEDDING_BATCH_SIZE=1`. + +## Try local LangChain + +Here is a question answerting example. + +Download a text file. + +```bash +wget https://raw.githubusercontent.com/hwchase17/langchain/v0.0.200/docs/modules/state_of_the_union.txt +``` + +Run LangChain. + +~~~py +from langchain.chat_models import ChatOpenAI +from langchain.document_loaders import TextLoader +from langchain.embeddings import OpenAIEmbeddings +from langchain.indexes import VectorstoreIndexCreator + +embedding = OpenAIEmbeddings(model="text-embedding-ada-002") +loader = TextLoader("state_of_the_union.txt") +index = VectorstoreIndexCreator(embedding=embedding).from_loaders([loader]) +llm = ChatOpenAI(model="gpt-3.5-turbo") + +questions = [ + "Who is the speaker", + "What did the president say about Ketanji Brown Jackson", + "What are the threats to America", + "Who are mentioned in the speech", + "Who is the vice president", + "How many projects were announced", +] + +for query in questions: + print("Query:", query) + print("Answer:", index.query(query, llm=llm)) +~~~ diff --git a/docs/lightllm_integration.md b/docs/lightllm_integration.md new file mode 100644 index 0000000000000000000000000000000000000000..b271a826a19927c3ab8ead8231308bfd4b65aa25 --- /dev/null +++ b/docs/lightllm_integration.md @@ -0,0 +1,18 @@ +# LightLLM Integration +You can use [LightLLM](https://github.com/ModelTC/lightllm) as an optimized worker implementation in FastChat. +It offers advanced continuous batching and a much higher (~10x) throughput. +See the supported models [here](https://github.com/ModelTC/lightllm?tab=readme-ov-file#supported-model-list). + +## Instructions +1. Please refer to the [Get started](https://github.com/ModelTC/lightllm?tab=readme-ov-file#get-started) to install LightLLM. Or use [Pre-built image](https://github.com/ModelTC/lightllm?tab=readme-ov-file#container) + +2. When you launch a model worker, replace the normal worker (`fastchat.serve.model_worker`) with the LightLLM worker (`fastchat.serve.lightllm_worker`). All other commands such as controller, gradio web server, and OpenAI API server are kept the same. Refer to [--max_total_token_num](https://github.com/ModelTC/lightllm/blob/4a9824b6b248f4561584b8a48ae126a0c8f5b000/docs/ApiServerArgs.md?plain=1#L23) to understand how to calculate the `--max_total_token_num` argument. 
+    ```
+    python3 -m fastchat.serve.lightllm_worker --model-path lmsys/vicuna-7b-v1.5 --tokenizer_mode "auto" --max_total_token_num 154000
+    ```
+
+    If you want to use quantized weights and kv cache for inference, try
+
+    ```
+    python3 -m fastchat.serve.lightllm_worker --model-path lmsys/vicuna-7b-v1.5 --tokenizer_mode "auto" --max_total_token_num 154000 --mode triton_int8weight triton_int8kv
+    ```
diff --git a/docs/mlx_integration.md b/docs/mlx_integration.md
new file mode 100644
index 0000000000000000000000000000000000000000..21642d94810ea4fd81d53cd03e158fdac2cece80
--- /dev/null
+++ b/docs/mlx_integration.md
@@ -0,0 +1,23 @@
+# Apple MLX Integration
+
+You can use [Apple MLX](https://github.com/ml-explore/mlx) as an optimized worker implementation in FastChat.
+
+It runs models efficiently on Apple Silicon.
+
+See the supported models [here](https://github.com/ml-explore/mlx-examples/tree/main/llms#supported-models).
+
+Note that for Apple Silicon Macs with less memory, smaller models (or quantized models) are recommended.
+
+## Instructions
+
+1. Install MLX.
+
+   ```
+   pip install "mlx-lm>=0.0.6"
+   ```
+
+2. When you launch a model worker, replace the normal worker (`fastchat.serve.model_worker`) with the MLX worker (`fastchat.serve.mlx_worker`). Remember to launch a model worker after you have launched the controller ([instructions](../README.md)).
+
+   ```
+   python3 -m fastchat.serve.mlx_worker --model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0
+   ```
diff --git a/docs/model_support.md b/docs/model_support.md
new file mode 100644
index 0000000000000000000000000000000000000000..ba9acf5b1c3977c5b14529299c80d161028049e3
--- /dev/null
+++ b/docs/model_support.md
@@ -0,0 +1,140 @@
+# Model Support
+This document describes how to support a new model in FastChat.
+
+## Content
+- [Local Models](#local-models)
+- [API-Based Models](#api-based-models)
+
+## Local Models
+To support a new local model in FastChat, you need to correctly handle its prompt template and model loading.
+The goal is to make the following command run with the correct prompts.
+
+```
+python3 -m fastchat.serve.cli --model-path [YOUR_MODEL_PATH]
+```
+
+You can run this example command to learn the code logic.
+
+```
+python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5
+```
+
+You can add `--debug` to see the actual prompt sent to the model.
+
+### Steps
+
+FastChat uses the `Conversation` class to handle prompt templates and the `BaseModelAdapter` class to handle model loading.
+
+1. Implement a conversation template for the new model at [fastchat/conversation.py](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py). You can follow existing examples and use `register_conv_template` to add a new one. Please also add a link to the official reference code if possible.
+2. Implement a model adapter for the new model at [fastchat/model/model_adapter.py](https://github.com/lm-sys/FastChat/blob/main/fastchat/model/model_adapter.py). You can follow existing examples and use `register_model_adapter` to add a new one.
+3. (Optional) Add the model name to the "Supported models" [section](#supported-models) below and add more information in [fastchat/model/model_registry.py](https://github.com/lm-sys/FastChat/blob/main/fastchat/model/model_registry.py).
+
+After these steps, the new model should be compatible with most FastChat features, such as CLI, web UI, model worker, and OpenAI-compatible API server. Please do some testing with these features as well.
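+
+For illustration, here is a minimal sketch of steps 1 and 2 for a hypothetical model whose checkpoints contain `my-model` in the path. It follows the pattern of the existing templates and adapters; treat the concrete values (roles, separators, stop string) as placeholders to adjust to your model's official prompt format.
+
+```python
+from fastchat.conversation import (
+    Conversation,
+    SeparatorStyle,
+    get_conv_template,
+    register_conv_template,
+)
+from fastchat.model.model_adapter import BaseModelAdapter, register_model_adapter
+
+# Step 1 (fastchat/conversation.py): register a conversation template.
+# "my-model" is a hypothetical name; the field values are placeholders.
+register_conv_template(
+    Conversation(
+        name="my-model",
+        system_message="A chat between a curious user and an artificial intelligence assistant.",
+        roles=("USER", "ASSISTANT"),
+        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
+        sep="\n",
+        stop_str="</s>",
+    )
+)
+
+
+# Step 2 (fastchat/model/model_adapter.py): register a model adapter that
+# selects the template above whenever the model path matches.
+class MyModelAdapter(BaseModelAdapter):
+    def match(self, model_path: str):
+        return "my-model" in model_path.lower()
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("my-model")
+
+
+register_model_adapter(MyModelAdapter)
+```
+
+In the real code base these two snippets live in the files named above rather than in a standalone script, and model loading can usually be inherited from `BaseModelAdapter` unless your model needs special loading arguments.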
+ +### Supported models + +- [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) + - example: `python3 -m fastchat.serve.cli --model-path meta-llama/Llama-2-7b-chat-hf` +- Vicuna, Alpaca, LLaMA, Koala + - example: `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5` +- [allenai/tulu-2-dpo-7b](https://huggingface.co/allenai/tulu-2-dpo-7b) +- [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B) +- [BAAI/AquilaChat2-7B](https://huggingface.co/BAAI/AquilaChat2-7B) +- [BAAI/AquilaChat2-34B](https://huggingface.co/BAAI/AquilaChat2-34B) +- [BAAI/bge-large-en](https://huggingface.co/BAAI/bge-large-en#using-huggingface-transformers) +- [argilla/notus-7b-v1](https://huggingface.co/argilla/notus-7b-v1) +- [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B) +- [BlinkDL/RWKV-4-Raven](https://huggingface.co/BlinkDL/rwkv-4-raven) + - example: `python3 -m fastchat.serve.cli --model-path ~/model_weights/RWKV-4-Raven-7B-v11x-Eng99%-Other1%-20230429-ctx8192.pth` +- [bofenghuang/vigogne-2-7b-instruct](https://huggingface.co/bofenghuang/vigogne-2-7b-instruct) +- [bofenghuang/vigogne-2-7b-chat](https://huggingface.co/bofenghuang/vigogne-2-7b-chat) +- [camel-ai/CAMEL-13B-Combined-Data](https://huggingface.co/camel-ai/CAMEL-13B-Combined-Data) +- [codellama/CodeLlama-7b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf) +- [databricks/dolly-v2-12b](https://huggingface.co/databricks/dolly-v2-12b) +- [deepseek-ai/deepseek-llm-67b-chat](https://huggingface.co/deepseek-ai/deepseek-llm-67b-chat) +- [deepseek-ai/deepseek-coder-33b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-33b-instruct) +- [FlagAlpha/Llama2-Chinese-13b-Chat](https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat) +- [FreedomIntelligence/phoenix-inst-chat-7b](https://huggingface.co/FreedomIntelligence/phoenix-inst-chat-7b) +- [FreedomIntelligence/ReaLM-7b-v1](https://huggingface.co/FreedomIntelligence/Realm-7b) +- [h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b](https://huggingface.co/h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b) +- [HuggingFaceH4/starchat-beta](https://huggingface.co/HuggingFaceH4/starchat-beta) +- [HuggingFaceH4/zephyr-7b-alpha](https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha) +- [internlm/internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b) +- [cllm/consistency-llm-7b-codesearchnet/consistency-llm-7b-gsm8k/consistency-llm-7b-sharegpt48k/consistency-llm-7b-spider](https://huggingface.co/cllm) +- [IEITYuan/Yuan2-2B/51B/102B-hf](https://huggingface.co/IEITYuan) +- [lcw99/polyglot-ko-12.8b-chang-instruct-chat](https://huggingface.co/lcw99/polyglot-ko-12.8b-chang-instruct-chat) +- [lmsys/fastchat-t5-3b-v1.0](https://huggingface.co/lmsys/fastchat-t5) +- [meta-math/MetaMath-7B-V1.0](https://huggingface.co/meta-math/MetaMath-7B-V1.0) +- [Microsoft/Orca-2-7b](https://huggingface.co/microsoft/Orca-2-7b) +- [mosaicml/mpt-7b-chat](https://huggingface.co/mosaicml/mpt-7b-chat) + - example: `python3 -m fastchat.serve.cli --model-path mosaicml/mpt-7b-chat` +- [Neutralzz/BiLLa-7B-SFT](https://huggingface.co/Neutralzz/BiLLa-7B-SFT) +- [nomic-ai/gpt4all-13b-snoozy](https://huggingface.co/nomic-ai/gpt4all-13b-snoozy) +- [NousResearch/Nous-Hermes-13b](https://huggingface.co/NousResearch/Nous-Hermes-13b) +- [openaccess-ai-collective/manticore-13b-chat-pyg](https://huggingface.co/openaccess-ai-collective/manticore-13b-chat-pyg) +- 
[OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5](https://huggingface.co/OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5) +- [openchat/openchat_3.5](https://huggingface.co/openchat/openchat_3.5) +- [Open-Orca/Mistral-7B-OpenOrca](https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca) +- [OpenLemur/lemur-70b-chat-v1](https://huggingface.co/OpenLemur/lemur-70b-chat-v1) +- [Phind/Phind-CodeLlama-34B-v2](https://huggingface.co/Phind/Phind-CodeLlama-34B-v2) +- [project-baize/baize-v2-7b](https://huggingface.co/project-baize/baize-v2-7b) +- [Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat) +- [rishiraj/CatPPT](https://huggingface.co/rishiraj/CatPPT) +- [Salesforce/codet5p-6b](https://huggingface.co/Salesforce/codet5p-6b) +- [StabilityAI/stablelm-tuned-alpha-7b](https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b) +- [tenyx/TenyxChat-7B-v1](https://huggingface.co/tenyx/TenyxChat-7B-v1) +- [TinyLlama/TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0) +- [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b) +- [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) +- [tiiuae/falcon-40b](https://huggingface.co/tiiuae/falcon-40b) +- [tiiuae/falcon-180B-chat](https://huggingface.co/tiiuae/falcon-180B-chat) +- [timdettmers/guanaco-33b-merged](https://huggingface.co/timdettmers/guanaco-33b-merged) +- [togethercomputer/RedPajama-INCITE-7B-Chat](https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Chat) +- [VMware/open-llama-7b-v2-open-instruct](https://huggingface.co/VMware/open-llama-7b-v2-open-instruct) +- [WizardLM/WizardLM-13B-V1.0](https://huggingface.co/WizardLM/WizardLM-13B-V1.0) +- [WizardLM/WizardCoder-15B-V1.0](https://huggingface.co/WizardLM/WizardCoder-15B-V1.0) +- [Xwin-LM/Xwin-LM-7B-V0.1](https://huggingface.co/Xwin-LM/Xwin-LM-70B-V0.1) +- Any [EleutherAI](https://huggingface.co/EleutherAI) pythia model such as [pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b) +- Any [Peft](https://github.com/huggingface/peft) adapter trained on top of a + model above. To activate, must have `peft` in the model path. Note: If + loading multiple peft models, you can have them share the base model weights by + setting the environment variable `PEFT_SHARE_BASE_WEIGHTS=true` in any model + worker. + + +## API-Based Models +To support an API-based model, consider learning from the existing OpenAI example. +If the model is compatible with OpenAI APIs, then a configuration file is all that's needed without any additional code. +For custom protocols, implementation of a streaming generator in [fastchat/serve/api_provider.py](https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/api_provider.py) is required, following the provided examples. Currently, FastChat is compatible with OpenAI, Anthropic, Google Vertex AI, Mistral, Nvidia NGC, YandexGPT and Reka. + +### Steps to Launch a WebUI with an API Model +1. Specify the endpoint information in a JSON configuration file. For instance, create a file named `api_endpoints.json`: +```json +{ + "gpt-3.5-turbo": { + "model_name": "gpt-3.5-turbo", + "api_type": "openai", + "api_base": "https://api.openai.com/v1", + "api_key": "sk-******", + "anony_only": false, + "recommended_config": { + "temperature": 0.7, + "top_p": 1.0 + }, + "text-arena": true, + "vision-arena": false, + } +} +``` + - "api_type" can be one of the following: openai, anthropic, gemini, mistral, yandexgpt or reka. For custom APIs, add a new type and implement it accordingly. 
+ - "anony_only" indicates whether to display this model in anonymous mode only. + - "recommended_config" indicates the recommended generation parameters for temperature and top_p. + - "text-arena" indicates whether the model should be displayed in the Text Arena. + - "vision-arena" indicates whether the model should be displayed in the Vision Arena. + +2. Launch the Gradio web server with the argument `--register api_endpoints.json`: +``` +python3 -m fastchat.serve.gradio_web_server --controller "" --share --register api_endpoints.json +``` + +Now, you can open a browser and interact with the model. diff --git a/docs/openai_api.md b/docs/openai_api.md new file mode 100644 index 0000000000000000000000000000000000000000..089b500ffc1f0bc6bfe8b023cdb147753a20068e --- /dev/null +++ b/docs/openai_api.md @@ -0,0 +1,152 @@ +# OpenAI-Compatible RESTful APIs + +FastChat provides OpenAI-compatible APIs for its supported models, so you can use FastChat as a local drop-in replacement for OpenAI APIs. +The FastChat server is compatible with both [openai-python](https://github.com/openai/openai-python) library and cURL commands. + +The following OpenAI APIs are supported: +- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat) +- Completions. (Reference: https://platform.openai.com/docs/api-reference/completions) +- Embeddings. (Reference: https://platform.openai.com/docs/api-reference/embeddings) + +The REST API can be seamlessly operated from Google Colab, as demonstrated in the [FastChat_API_GoogleColab.ipynb](https://github.com/lm-sys/FastChat/blob/main/playground/FastChat_API_GoogleColab.ipynb) notebook, available in our repository. This notebook provides a practical example of how to utilize the API effectively within the Google Colab environment. + +## RESTful API Server +First, launch the controller + +```bash +python3 -m fastchat.serve.controller +``` + +Then, launch the model worker(s) + +```bash +python3 -m fastchat.serve.model_worker --model-path lmsys/vicuna-7b-v1.5 +``` + +Finally, launch the RESTful API server + +```bash +python3 -m fastchat.serve.openai_api_server --host localhost --port 8000 +``` + +Now, let us test the API server. + +### OpenAI Official SDK +The goal of `openai_api_server.py` is to implement a fully OpenAI-compatible API server, so the models can be used directly with [openai-python](https://github.com/openai/openai-python) library. + +First, install OpenAI python package >= 1.0: +```bash +pip install --upgrade openai +``` + +Then, interact with the Vicuna model: +```python +import openai + +openai.api_key = "EMPTY" +openai.base_url = "http://localhost:8000/v1/" + +model = "vicuna-7b-v1.5" +prompt = "Once upon a time" + +# create a completion +completion = openai.completions.create(model=model, prompt=prompt, max_tokens=64) +# print the completion +print(prompt + completion.choices[0].text) + +# create a chat completion +completion = openai.chat.completions.create( + model=model, + messages=[{"role": "user", "content": "Hello! What is your name?"}] +) +# print the completion +print(completion.choices[0].message.content) +``` + +Streaming is also supported. See [test_openai_api.py](../tests/test_openai_api.py). If your api server is behind a proxy you'll need to turn off buffering, you can do so in Nginx by setting `proxy_buffering off;` in the location block for the proxy. + +### cURL +cURL is another good tool for observing the output of the api. 
+
+List Models:
+```bash
+curl http://localhost:8000/v1/models
+```
+
+Chat Completions:
+```bash
+curl http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "vicuna-7b-v1.5",
+    "messages": [{"role": "user", "content": "Hello! What is your name?"}]
+  }'
+```
+
+Text Completions:
+```bash
+curl http://localhost:8000/v1/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "vicuna-7b-v1.5",
+    "prompt": "Once upon a time",
+    "max_tokens": 41,
+    "temperature": 0.5
+  }'
+```
+
+Embeddings:
+```bash
+curl http://localhost:8000/v1/embeddings \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "vicuna-7b-v1.5",
+    "input": "Hello world!"
+  }'
+```
+
+### Running multiple
+
+If you want to run multiple models on the same machine and in the same process,
+you can replace the `model_worker` step above with a multi model variant:
+
+```bash
+python3 -m fastchat.serve.multi_model_worker \
+    --model-path lmsys/vicuna-7b-v1.5 \
+    --model-names vicuna-7b-v1.5 \
+    --model-path lmsys/longchat-7b-16k \
+    --model-names longchat-7b-16k
+```
+
+This loads both models into the same accelerator and in the same process. This
+works best when using a Peft model that triggers the `PeftModelAdapter`.
+
+TODO: Base model weight optimization will be fixed once [this
+Peft](https://github.com/huggingface/peft/issues/430) issue is resolved.
+
+## LangChain Support
+This OpenAI-compatible API server supports LangChain. See [LangChain Integration](langchain_integration.md) for details.
+
+## Adjusting Environment Variables
+
+### Timeout
+By default, a timeout error will occur if a model worker does not respond within 100 seconds. If your model/hardware is slower, you can change this timeout through an environment variable:
+
+```bash
+export FASTCHAT_WORKER_API_TIMEOUT=<larger timeout in seconds>
+```
+
+### Batch size
+If you meet an OOM error while creating embeddings, you can use a smaller batch size by setting
+
+```bash
+export FASTCHAT_WORKER_API_EMBEDDING_BATCH_SIZE=1
+```
+
+## Todos
+Some features to be implemented:
+
+- [ ] Support more parameters like `logprobs`, `logit_bias`, `user`, `presence_penalty` and `frequency_penalty`
+- [ ] Model details (permissions, owner and create time)
+- [ ] Edits API
+- [ ] Rate Limitation Settings
diff --git a/docs/server_arch.md b/docs/server_arch.md
new file mode 100644
index 0000000000000000000000000000000000000000..1ccc8a1623569d8612858cfaada2704875e46a49
--- /dev/null
+++ b/docs/server_arch.md
@@ -0,0 +1,2 @@
+# FastChat Server Architecture
+![server arch](../assets/server_arch.png)
diff --git a/docs/third_party_ui.md b/docs/third_party_ui.md
new file mode 100644
index 0000000000000000000000000000000000000000..c0b2301500455965d46641c8d80ef2d7e54224e2
--- /dev/null
+++ b/docs/third_party_ui.md
@@ -0,0 +1,24 @@
+# Third Party UI
+If you want to use FastChat with your own UI or a third-party UI, you can launch the [OpenAI compatible server](openai_api.md), expose it with a tunnelling service such as Tunnelmole or ngrok, and then enter the credentials appropriately in that UI.
+
+You can find suitable UIs from third party repos:
+- [WongSaang's ChatGPT UI](https://github.com/WongSaang/chatgpt-ui)
+- [McKayWrigley's Chatbot UI](https://github.com/mckaywrigley/chatbot-ui)
+
+- Please note that some third-party providers only offer the standard `gpt-3.5-turbo`, `gpt-4`, etc., so you will have to add your own custom model inside the code.
[Here is an example of how to create a UI with any custom model name](https://github.com/ztjhz/BetterChatGPT/pull/461). + +##### Using Tunnelmole +Tunnelmole is an open source tunnelling tool. You can find its source code on [Github](https://github.com/robbie-cahill/tunnelmole-client). Here's how you can use Tunnelmole: +1. Install Tunnelmole with `curl -O https://install.tunnelmole.com/9Wtxu/install && sudo bash install`. (On Windows, download [tmole.exe](https://tunnelmole.com/downloads/tmole.exe)). Head over to the [README](https://github.com/robbie-cahill/tunnelmole-client) for other methods such as `npm` or building from source. +2. Run `tmole 7860` (replace `7860` with your listening port if it is different from 7860). The output will display two URLs: one HTTP and one HTTPS. It's best to use the HTTPS URL for better privacy and security. +``` +➜ ~ tmole 7860 +http://bvdo5f-ip-49-183-170-144.tunnelmole.net is forwarding to localhost:7860 +https://bvdo5f-ip-49-183-170-144.tunnelmole.net is forwarding to localhost:7860 +``` + +##### Using ngrok +ngrok is a popular closed source tunnelling tool. First download and install it from [ngrok.com](https://ngrok.com/downloads). Here's how to use it to expose port 7860. +``` +ngrok http 7860 +``` diff --git a/docs/training.md b/docs/training.md new file mode 100644 index 0000000000000000000000000000000000000000..87b87312f972d9d5848d58cc769a19bc4234d650 --- /dev/null +++ b/docs/training.md @@ -0,0 +1,118 @@ +### Fine-tuning FastChat-T5 +You can use the following command to train FastChat-T5 with 4 x A100 (40GB). +```bash +torchrun --nproc_per_node=4 --master_port=9778 fastchat/train/train_flant5.py \ + --model_name_or_path google/flan-t5-xl \ + --data_path ./data/dummy_conversation.json \ + --bf16 True \ + --output_dir ./checkpoints_flant5_3b \ + --num_train_epochs 3 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 4 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 300 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --fsdp "full_shard auto_wrap" \ + --fsdp_transformer_layer_cls_to_wrap T5Block \ + --tf32 True \ + --model_max_length 2048 \ + --preprocessed_path ./preprocessed_data/processed.json \ + --gradient_checkpointing True +``` + +After training, please use our post-processing [function](https://github.com/lm-sys/FastChat/blob/55051ad0f23fef5eeecbda14a2e3e128ffcb2a98/fastchat/utils.py#L166-L185) to update the saved model weight. Additional discussions can be found [here](https://github.com/lm-sys/FastChat/issues/643). + +### Fine-tuning using (Q)LoRA +You can use the following command to train Vicuna-7B using QLoRA using ZeRO2. Note that ZeRO3 is not currently supported with QLoRA but ZeRO3 does support LoRA, which has a reference configuraiton under playground/deepspeed_config_s3.json. To use QLoRA, you must have bitsandbytes>=0.39.0 and transformers>=4.30.0 installed. 
+```bash +deepspeed fastchat/train/train_lora.py \ + --model_name_or_path ~/model_weights/llama-7b \ + --lora_r 8 \ + --lora_alpha 16 \ + --lora_dropout 0.05 \ + --data_path ./data/dummy_conversation.json \ + --bf16 True \ + --output_dir ./checkpoints \ + --num_train_epochs 3 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 1200 \ + --save_total_limit 100 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --q_lora True \ + --deepspeed playground/deepspeed_config_s2.json \ +``` + +For T5-XL or XXL + +```bash +deepspeed fastchat/train/train_lora_t5.py \ + --model_name_or_path google/flan-t5-xl \ + --data_path ./data/dummy_conversation.json \ + --bf16 True \ + --output_dir ./checkpoints_flant5_3b \ + --num_train_epochs 3 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 4 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 300 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --model_max_length 2048 \ + --preprocessed_path ./preprocessed_data/processed.json \ + --gradient_checkpointing True \ + --q_lora True \ + --deepspeed playground/deepspeed_config_s2.json + +``` + +### Fine-tuning Vicuna-7B with Local NPUs + +You can use the following command to train Vicuna-7B with 8 x NPUs. Use `--nproc_per_node` to specify the number of NPUs. +```bash +torchrun --nproc_per_node=8 --master_port=20001 fastchat/train/train.py \ + --model_name_or_path ~/vicuna-7b-v1.5-16k \ + --data_path data/dummy_conversation.json \ + --fp16 True \ + --output_dir output_vicuna \ + --num_train_epochs 3 \ + --per_device_train_batch_size 8 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 1200 \ + --save_total_limit 10 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --fsdp "full_shard auto_wrap" \ + --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --lazy_preprocess True +``` diff --git a/docs/vicuna_weights_version.md b/docs/vicuna_weights_version.md new file mode 100644 index 0000000000000000000000000000000000000000..7fbb5625093430b7edb493d5385542d4a8ccc6f1 --- /dev/null +++ b/docs/vicuna_weights_version.md @@ -0,0 +1,97 @@ +## Vicuna Weights + +| Weights version | Link | FastChat version compatibility | Base Model | Release Date | Fine-tuning Data | +| ---- | ---- | ---- | ---- | ---- | ---- | +| v1.5 | [7B](https://huggingface.co/lmsys/vicuna-7b-v1.5), [7B-16k](https://huggingface.co/lmsys/vicuna-7b-v1.5-16k), [13B](https://huggingface.co/lmsys/vicuna-13b-v1.5), [13B-16k](https://huggingface.co/lmsys/vicuna-13b-v1.5-16k) | `>=0.2.21` | Llama 2 | Aug. 1, 2023 | 370M tokens | +| v1.3 | [7B](https://huggingface.co/lmsys/vicuna-7b-v1.3), [13B](https://huggingface.co/lmsys/vicuna-13b-v1.3), [33B](//huggingface.co/lmsys/vicuna-33b-v1.3) | `>=0.2.1` | Llama 1 | Jun. 22, 2023 | 370M tokens | +| v1.1 | [7B](https://huggingface.co/lmsys/vicuna-7b-v1.1), [13B](https://huggingface.co/lmsys/vicuna-13b-v1.1) | `>=0.2.1` | Llama 1 | Apr. 
12, 2023 | - | +| v0 | [7B-delta](https://huggingface.co/lmsys/vicuna-7b-delta-v0), [13B-delta](https://huggingface.co/lmsys/vicuna-13b-delta-v0) | `<=0.1.10` | Llama 1 | Mar. 30, 2023 | - | + +### Updates +- Major updates of weights v1.5 + - Use Llama2 as the base model. + - Provide 16K context length versions using linear RoPE scaling. + +- Major updates of weights v1.3 + - Train with twice the amount of ShareGPT data compared to previous versions. + - Provide merged weights directly instead of delta weights. + +- Major updates of weights v1.1 + - Refactor the tokenization and separator. In Vicuna v1.1, the separator has been changed from `###` to the EOS token ``. This change makes it easier to determine the generation stop criteria and enables better compatibility with other libraries. + - Fix the supervised fine-tuning loss computation for better model quality. + +## Prompt Template + +### Example prompt (weights v1.1, v1.3, v1.5) +``` +A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. + +USER: Hello! +ASSISTANT: Hello! +USER: How are you? +ASSISTANT: I am good. +``` + +See a full prompt template [here](https://github.com/lm-sys/FastChat/blob/d578599c69d060e6d40943f1b5b72af98956092a/fastchat/conversation.py#L286-L299) and example output [here](https://github.com/lm-sys/FastChat/blob/d578599c69d060e6d40943f1b5b72af98956092a/fastchat/conversation.py#L748-L753). + +### Example prompt (weights v0) +``` +A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. + +### Human: Hello! +### Assistant: Hello! +### Human: How are you? +### Assistant: I am good. +``` + +See the full prompt template [here](https://github.com/lm-sys/FastChat/blob/d578599c69d060e6d40943f1b5b72af98956092a/fastchat/conversation.py#L238-L269). + +## How to Apply Delta Weights (Only Needed for Weights v0) + +We release [Vicuna](https://lmsys.org/blog/2023-03-30-vicuna/) weights v0 as delta weights to comply with the LLaMA model license. +You can add our delta to the original LLaMA weights to obtain the Vicuna weights. Instructions: + +1. Get the original LLaMA weights in the Hugging Face format by following the instructions [here](https://huggingface.co/docs/transformers/main/model_doc/llama). +2. Use the following scripts to get Vicuna weights by applying our delta. They will automatically download delta weights from our Hugging Face [account](https://huggingface.co/lmsys). + +**NOTE**: +Weights v1.1 are only compatible with ```transformers>=4.28.0``` and ``fschat >= 0.2.0``. +Please update your local packages accordingly. If you follow the above commands to do a fresh install, then you should get all the correct versions. + +#### Vicuna-7B +This conversion command needs around 30 GB of CPU RAM. +See the "Low CPU Memory Conversion" section below if you do not have enough memory. +Replace `/path/to/*` with the real paths. +```bash +python3 -m fastchat.model.apply_delta \ + --base-model-path /path/to/llama-7b \ + --target-model-path /path/to/output/vicuna-7b \ + --delta-path lmsys/vicuna-7b-delta-v1.1 +``` + +#### Vicuna-13B +This conversion command needs around 60 GB of CPU RAM. +See the "Low CPU Memory Conversion" section below if you do not have enough memory. +Replace `/path/to/*` with the real paths. 
+
+```bash
+python3 -m fastchat.model.apply_delta \
+    --base-model-path /path/to/llama-13b \
+    --target-model-path /path/to/output/vicuna-13b \
+    --delta-path lmsys/vicuna-13b-delta-v1.1
+```
+
+#### Low CPU Memory Conversion
+You can try these methods to reduce the CPU RAM requirement of weight conversion.
+1. Append `--low-cpu-mem` to the commands above, which will split large weight files into smaller ones and use the disk as temporary storage. This can keep the peak memory at less than 16GB.
+2. Create a large swap file and rely on the operating system to automatically utilize the disk as virtual memory.
+
+## FAQ
+
+### Tokenizer issues
+There are some frequently asked tokenizer issues (https://github.com/lm-sys/FastChat/issues/408).
+Some of them are not only related to FastChat or Vicuna weights but are also related to how you convert the base llama model.
+
+We suggest that you use `transformers>=4.28.0` and redo the weight conversion for the base llama model.
+After applying the delta, you should have a file named `special_tokens_map.json` in your converted weight folder for either v0 or v1.1.
+The contents of this file should be the same as this file: https://huggingface.co/lmsys/vicuna-13b-delta-v0/blob/main/special_tokens_map.json.
+If the file is not present, please copy the `special_tokens_map.json` and `tokenizer_config.json` files from https://huggingface.co/lmsys/vicuna-13b-delta-v0/tree/main to your converted weight folder. This works for both v0 and v1.1.
diff --git a/docs/vllm_integration.md b/docs/vllm_integration.md
new file mode 100644
index 0000000000000000000000000000000000000000..7d3205bb8b89af9a9cdc61defbf6ef2be8cc58b2
--- /dev/null
+++ b/docs/vllm_integration.md
@@ -0,0 +1,25 @@
+# vLLM Integration
+You can use [vLLM](https://vllm.ai/) as an optimized worker implementation in FastChat.
+It offers advanced continuous batching and a much higher (~10x) throughput.
+See the supported models [here](https://vllm.readthedocs.io/en/latest/models/supported_models.html).
+
+## Instructions
+1. Install vLLM.
+   ```
+   pip install vllm
+   ```
+
+2. When you launch a model worker, replace the normal worker (`fastchat.serve.model_worker`) with the vLLM worker (`fastchat.serve.vllm_worker`). All other commands such as controller, gradio web server, and OpenAI API server are kept the same.
+   ```
+   python3 -m fastchat.serve.vllm_worker --model-path lmsys/vicuna-7b-v1.5
+   ```
+
+   If you see tokenizer errors, try
+   ```
+   python3 -m fastchat.serve.vllm_worker --model-path lmsys/vicuna-7b-v1.5 --tokenizer hf-internal-testing/llama-tokenizer
+   ```
+
+   If you use an AWQ quantized model, try
+   ```
+   python3 -m fastchat.serve.vllm_worker --model-path TheBloke/vicuna-7B-v1.5-AWQ --quantization awq
+   ```
diff --git a/docs/xFasterTransformer.md b/docs/xFasterTransformer.md
new file mode 100644
index 0000000000000000000000000000000000000000..e9cdae430daee3eb408802821a988fb5fd3c9c3b
--- /dev/null
+++ b/docs/xFasterTransformer.md
@@ -0,0 +1,90 @@
+# xFasterTransformer Inference Framework
+
+We integrated the [xFasterTransformer](https://github.com/intel/xFasterTransformer) framework into FastChat to provide **faster** inference on Intel CPUs.
+
+## Install xFasterTransformer
+
+Setup environment (please refer to [this link](https://github.com/intel/xFasterTransformer#installation) for more details):
+
+```bash
+pip install xfastertransformer
+```
+
+## Prepare models
+
+Prepare the model (please refer to [this link](https://github.com/intel/xFasterTransformer#prepare-model) for more details):
+```bash
+python ./tools/chatglm_convert.py -i ${HF_DATASET_DIR} -o ${OUTPUT_DIR}
+```
+
+## Parameters of xFasterTransformer
+- `--enable-xft` enables xFasterTransformer in FastChat.
+- `--xft-max-seq-len` sets the max token length the model can process. The max token length includes the input token length.
+- `--xft-dtype` sets the data type used by xFasterTransformer for computation. xFasterTransformer supports fp32, fp16, int8, bf16 and hybrid data types such as bf16_fp16 and bf16_int8. For data type details, please refer to [this link](https://github.com/intel/xFasterTransformer/wiki/Data-Type-Support-Platform).
+
+
+Chat with the CLI:
+```bash
+#run inference on all CPUs and using float16
+python3 -m fastchat.serve.cli \
+    --model-path /path/to/models \
+    --enable-xft \
+    --xft-dtype fp16
+```
+or with numactl on multi-socket server for better performance
+```bash
+#run inference on numanode 0 and with data type bf16_fp16 (first token uses bfloat16, and rest tokens use float16)
+numactl -N 0 --localalloc \
+python3 -m fastchat.serve.cli \
+    --model-path /path/to/models/chatglm2_6b_cpu/ \
+    --enable-xft \
+    --xft-dtype bf16_fp16
+```
+or using MPI to run inference on 2 sockets for better performance
+```bash
+#run inference on numanode 0 and 1 and with data type bf16_fp16 (first token uses bfloat16, and rest tokens use float16)
+OMP_NUM_THREADS=$CORE_NUM_PER_SOCKET LD_PRELOAD=libiomp5.so mpirun \
+-n 1 numactl -N 0 --localalloc \
+python -m fastchat.serve.cli \
+    --model-path /path/to/models/chatglm2_6b_cpu/ \
+    --enable-xft \
+    --xft-dtype bf16_fp16 : \
+-n 1 numactl -N 1 --localalloc \
+python -m fastchat.serve.cli \
+    --model-path /path/to/models/chatglm2_6b_cpu/ \
+    --enable-xft \
+    --xft-dtype bf16_fp16
+```
+
+
+Start model worker:
+```bash
+# Load model with default configuration (max sequence length 4096, no GPU split setting).
+python3 -m fastchat.serve.model_worker \ + --model-path /path/to/models \ + --enable-xft \ + --xft-dtype bf16_fp16 +``` +or with numactl on multi-socket server for better performance +```bash +#run inference on numanode 0 and with data type bf16_fp16 (first token uses bfloat16, and rest tokens use float16) +numactl -N 0 --localalloc python3 -m fastchat.serve.model_worker \ + --model-path /path/to/models \ + --enable-xft \ + --xft-dtype bf16_fp16 +``` +or using MPI to run inference on 2 sockets for better performance +```bash +#run inference on numanode 0 and 1 and with data type bf16_fp16 (first token uses bfloat16, and rest tokens use float16) +OMP_NUM_THREADS=$CORE_NUM_PER_SOCKET LD_PRELOAD=libiomp5.so mpirun \ +-n 1 numactl -N 0 --localalloc python -m fastchat.serve.model_worker \ + --model-path /path/to/models \ + --enable-xft \ + --xft-dtype bf16_fp16 : \ +-n 1 numactl -N 1 --localalloc python -m fastchat.serve.model_worker \ + --model-path /path/to/models \ + --enable-xft \ + --xft-dtype bf16_fp16 +``` + +For more details, please refer to [this link](https://github.com/intel/xFasterTransformer#how-to-run) diff --git a/fastchat/__init__.py b/fastchat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c971add65283be43ea849cf3475ec450317aa649 --- /dev/null +++ b/fastchat/__init__.py @@ -0,0 +1 @@ +__version__ = "0.2.36" diff --git a/fastchat/constants.py b/fastchat/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..9605b03ed1cfe6b20a667f6bbce0c822f621c787 --- /dev/null +++ b/fastchat/constants.py @@ -0,0 +1,82 @@ +""" +Global constants. +""" + +from enum import IntEnum +import os + +REPO_PATH = os.path.dirname(os.path.dirname(__file__)) + +# Survey Link URL (to be removed) +SURVEY_LINK = """
+
+ We would love your feedback! Fill out this short survey to tell us what you like about the arena, what you don't like, and what you want to see in the future. +
+
""" + +##### For the gradio web server +SERVER_ERROR_MSG = ( + "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" +) +TEXT_MODERATION_MSG = ( + "$MODERATION$ YOUR TEXT VIOLATES OUR CONTENT MODERATION GUIDELINES." +) +IMAGE_MODERATION_MSG = ( + "$MODERATION$ YOUR IMAGE VIOLATES OUR CONTENT MODERATION GUIDELINES." +) +MODERATION_MSG = "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES." +CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION." +INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE." +SLOW_MODEL_MSG = "⚠️ Both models will show the responses all at once. Please stay patient as it may take over 30 seconds." +RATE_LIMIT_MSG = "**RATE LIMIT OF THIS MODEL IS REACHED. PLEASE COME BACK LATER OR USE [BATTLE MODE](https://chat.lmsys.org) (the 1st tab).**" +# Maximum input length +INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 12000)) +BLIND_MODE_INPUT_CHAR_LEN_LIMIT = int( + os.getenv("FASTCHAT_BLIND_MODE_INPUT_CHAR_LEN_LIMIT", 30000) +) +# Maximum conversation turns +CONVERSATION_TURN_LIMIT = 50 +# Session expiration time +SESSION_EXPIRATION_TIME = 3600 +# The output dir of log files +LOGDIR = os.getenv("LOGDIR", ".") +# CPU Instruction Set Architecture +CPU_ISA = os.getenv("CPU_ISA") + + +##### For the controller and workers (could be overwritten through ENV variables.) +CONTROLLER_HEART_BEAT_EXPIRATION = int( + os.getenv("FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION", 90) +) +WORKER_HEART_BEAT_INTERVAL = int(os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", 45)) +WORKER_API_TIMEOUT = int(os.getenv("FASTCHAT_WORKER_API_TIMEOUT", 100)) +WORKER_API_EMBEDDING_BATCH_SIZE = int( + os.getenv("FASTCHAT_WORKER_API_EMBEDDING_BATCH_SIZE", 4) +) + + +class ErrorCode(IntEnum): + """ + https://platform.openai.com/docs/guides/error-codes/api-errors + """ + + VALIDATION_TYPE_ERROR = 40001 + + INVALID_AUTH_KEY = 40101 + INCORRECT_AUTH_KEY = 40102 + NO_PERMISSION = 40103 + + INVALID_MODEL = 40301 + PARAM_OUT_OF_RANGE = 40302 + CONTEXT_OVERFLOW = 40303 + + RATE_LIMIT = 42901 + QUOTA_EXCEEDED = 42902 + ENGINE_OVERLOADED = 42903 + + INTERNAL_ERROR = 50001 + CUDA_OUT_OF_MEMORY = 50002 + GRADIO_REQUEST_ERROR = 50003 + GRADIO_STREAM_UNKNOWN_ERROR = 50004 + CONTROLLER_NO_WORKER = 50005 + CONTROLLER_WORKER_TIMEOUT = 50006 diff --git a/fastchat/conversation.py b/fastchat/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..af916c97f2a2c2874322c1929e2d38133ab57b90 --- /dev/null +++ b/fastchat/conversation.py @@ -0,0 +1,2117 @@ +""" +Conversation prompt templates. + +We kindly request that you import fastchat instead of copying this file if you wish to use it. +If you have any changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates. 
+""" + +import base64 +import dataclasses +from enum import auto, IntEnum +from io import BytesIO +import os +from typing import List, Any, Dict, Union, Tuple + + +class SeparatorStyle(IntEnum): + """Separator styles.""" + + ADD_COLON_SINGLE = auto() + ADD_COLON_TWO = auto() + ADD_COLON_SPACE_SINGLE = auto() + NO_COLON_SINGLE = auto() + NO_COLON_TWO = auto() + ADD_NEW_LINE_SINGLE = auto() + LLAMA2 = auto() + LLAMA3 = auto() + CHATGLM = auto() + CHATML = auto() + CHATINTERN = auto() + DOLLY = auto() + RWKV = auto() + PHOENIX = auto() + ROBIN = auto() + FALCON_CHAT = auto() + CHATGLM3 = auto() + DEEPSEEK_CHAT = auto() + METAMATH = auto() + YUAN2 = auto() + GEMMA = auto() + CLLM = auto() + DEFAULT = auto() + + +IMAGE_PLACEHOLDER_STR = "$$$$" + + +@dataclasses.dataclass +class Conversation: + """A class that manages prompt templates and keeps all conversation history.""" + + # The name of this template + name: str + # The template of the system prompt + system_template: str = "{system_message}" + # The system message + system_message: str = "" + system_message_vision: str = "" + # The names of two roles + roles: Tuple[str] = ("USER", "ASSISTANT") + # All messages. Each item is (role, message). + # Each message is either a string or a tuple of (string, List[image_url]). + messages: List[List[str]] = () + # The number of few shot examples + offset: int = 0 + # The separator style and configurations + sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE + sep: str = "\n" + sep2: str = None + # Stop criteria (the default one is EOS token) + stop_str: Union[str, List[str]] = None + # Stops generation if meeting any token in this list + stop_token_ids: List[int] = None + # The maximum image size in megabytes that this model takes in. None means we do not resize the image. 
+ max_image_size_mb: int = None + + def get_prompt(self) -> str: + """Get the prompt for generation.""" + system_prompt = self.system_template.format(system_message=self.system_message) + if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ": " + message + self.sep + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: + seps = [self.sep, self.sep2] + ret = system_prompt + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + if type(message) is tuple: + message, images = message + message = IMAGE_PLACEHOLDER_STR * len(images) + message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ": " + message + self.sep + else: + ret += role + ": " # must be end with a space + return ret + elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: + ret = "" if system_prompt == "" else system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + "\n" + message + self.sep + else: + ret += role + "\n" + return ret + elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: + ret = system_prompt + for role, message in self.messages: + if message: + ret += role + message + self.sep + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.NO_COLON_TWO: + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + message + seps[i % 2] + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.RWKV: + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += ( + role + + ": " + + message.replace("\r\n", "\n").replace("\n\n", "\n") + ) + ret += "\n\n" + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.LLAMA2: + seps = [self.sep, self.sep2] + if self.system_message: + ret = system_prompt + else: + ret = "[INST] " + for i, (role, message) in enumerate(self.messages): + tag = self.roles[i % 2] + if message: + if i == 0: + ret += message + " " + else: + ret += tag + " " + message + seps[i % 2] + else: + ret += tag + return ret + elif self.sep_style == SeparatorStyle.LLAMA3: + ret = "<|begin_of_text|>" + if self.system_message: + ret += system_prompt + else: + ret += "" + for i, (role, message) in enumerate(self.messages): + if message: + ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" + ret += f"{message.strip()}<|eot_id|>" + else: + ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" + return ret + elif self.sep_style == SeparatorStyle.CHATGLM: + # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 + # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 + round_add_n = 1 if self.name == "chatglm2" else 0 + if system_prompt: + ret = system_prompt + self.sep + else: + ret = "" + + for i, (role, message) in enumerate(self.messages): + if i % 2 == 0: + ret += f"[Round {i//2 + round_add_n}]{self.sep}" + + if message: + ret += f"{role}:{message}{self.sep}" + else: + ret += f"{role}:" + return ret + elif self.sep_style == SeparatorStyle.CHATML: + ret = "" if system_prompt == "" else 
system_prompt + self.sep + "\n" + for role, message in self.messages: + if message: + if type(message) is tuple: + message, images = message + message = IMAGE_PLACEHOLDER_STR * len(images) + message + ret += role + "\n" + message + self.sep + "\n" + else: + ret += role + "\n" + return ret + elif self.sep_style == SeparatorStyle.CHATGLM3: + ret = "" + if self.system_message: + ret += system_prompt + for role, message in self.messages: + if message: + ret += role + "\n" + message + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.CHATINTERN: + # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if i % 2 == 0: + ret += "" + if message: + ret += role + ":" + message + seps[i % 2] + "\n" + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.DOLLY: + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ":\n" + message + seps[i % 2] + if i % 2 == 1: + ret += "\n\n" + else: + ret += role + ":\n" + return ret + elif self.sep_style == SeparatorStyle.PHOENIX: + ret = system_prompt + for role, message in self.messages: + if message: + ret += role + ": " + "" + message + "" + else: + ret += role + ": " + "" + return ret + elif self.sep_style == SeparatorStyle.ROBIN: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ":\n" + message + self.sep + else: + ret += role + ":\n" + return ret + elif self.sep_style == SeparatorStyle.FALCON_CHAT: + ret = "" + if self.system_message: + ret += system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ": " + message + self.sep + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.METAMATH: + ret = "" if system_prompt == "" else system_prompt + self.sep + for i, (role, message) in enumerate(self.messages): + # For MetaMath, sep2 is used to prefix the message. 
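+                # Even-indexed turns (the user) are rendered as role + ":\n" + message followed by sep;
+                # odd-indexed turns (the assistant) are prefixed with ": " + sep2 (the metamath template
+                # sets sep2 to "Let's think step by step.") and carry no trailing separator.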
+ starting_sep = ":\n" if i % 2 == 0 else ": " + self.sep2 + ending_sep = self.sep if i % 2 == 0 else "" + if message: + ret += role + starting_sep + message + ending_sep + else: + ret += role + starting_sep + return ret + elif self.sep_style == SeparatorStyle.DEEPSEEK_CHAT: + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.YUAN2: + seps = [self.sep, self.sep2] + ret = "" + if self.system_message: + ret += system_prompt + seps[1] + for _, message in self.messages: + if message: + ret += message + "" + else: + ret += "" + ret = ret.rstrip("") + seps[0] + return ret + elif self.sep_style == SeparatorStyle.GEMMA: + ret = "" + for role, message in self.messages: + if message: + ret += "" + role + "\n" + message + self.sep + else: + ret += "" + role + "\n" + return ret + elif self.sep_style == SeparatorStyle.CLLM: + seps = [self.sep, self.sep2] + ret = system_prompt + seps[0] + for i, (role, message) in enumerate(self.messages[-2:]): + if message: + if type(message) is tuple: + message, images = message + message = IMAGE_PLACEHOLDER_STR * len(images) + message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.DEFAULT: + ret = system_prompt + "\n" + for role, message in self.messages: + if message: + if type(message) is tuple: + message, images = message + ret += role + ": " + message + "\n" + else: + ret += role + ":" + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def get_images(self): + images = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + for image in msg[1]: + images.append(image.base64_str) + + return images + + def set_system_message(self, system_message: str): + """Set the system message.""" + self.system_message = system_message + + def get_system_message(self, is_vision=False): + """return the system message.""" + if is_vision and self.system_message_vision: + return self.system_message_vision + return self.system_message + + def append_message(self, role: str, message: str): + """Append a new message.""" + self.messages.append([role, message]) + + def update_last_message(self, message: str): + """Update the last output. + + The last message is typically set to be None when constructing the prompt, + so we need to update it in-place after getting the response from a model. 
+ """ + self.messages[-1][1] = message + + def to_gradio_chatbot(self): + """Convert the conversation to gradio chatbot format.""" + from fastchat.serve.vision.image import ImageFormat + + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + msg, images = msg + image = images[0] # Only one image on gradio at one time + if image.image_format == ImageFormat.URL: + img_str = f'user upload image' + elif image.image_format == ImageFormat.BYTES: + img_str = f'user upload image' + msg = img_str + msg.replace("\n", "").strip() + + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def to_openai_vision_api_messages(self): + """Convert the conversation to OpenAI vision api completion format""" + if self.system_message == "": + ret = [] + else: + ret = [ + { + "role": "system", + "content": [{"type": "text", "text": self.system_message}], + } + ] + + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + content_list = [{"type": "text", "text": msg[0]}] + image_urls = msg[1] + for image in image_urls: + image_url = image.to_openai_image_format() + content_list.append( + {"type": "image_url", "image_url": {"url": image_url}} + ) + + ret.append({"role": "user", "content": content_list}) + else: + ret.append( + {"role": "user", "content": [{"type": "text", "text": msg}]} + ) + else: + if msg is not None: + ret.append( + { + "role": "assistant", + "content": [{"type": "text", "text": msg}], + } + ) + return ret + + def to_openai_api_messages(self): + """Convert the conversation to OpenAI chat completion format.""" + if self.system_message == "": + ret = [] + else: + ret = [{"role": "system", "content": self.system_message}] + + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append({"role": "user", "content": msg}) + else: + if msg is not None: + ret.append({"role": "assistant", "content": msg}) + return ret + + def to_gemini_api_messages(self): + from fastchat.utils import load_image + + if self.system_message == "": + ret = [] + else: + ret = [{"role": "system", "content": self.system_message}] + + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + text, images = msg[0], msg[1] + content_list = [text] + for image in images: + pil_image = load_image(image.base64_str) + content_list.append(pil_image) + ret.append({"role": "user", "content": content_list}) + else: + ret.append({"role": "user", "content": msg}) + else: + if msg is not None: + ret.append({"role": "model", "content": msg}) + return ret + + def to_vertex_api_messages(self): + from vertexai.preview.generative_models import Image + import base64 + import requests + from fastchat.serve.vision.image import ImageFormat + + if self.system_message == "": + ret = [] + else: + ret = [self.system_message] + + for role, msg in self.messages[self.offset :]: + if msg is not None: + if type(msg) is tuple: + text, images = msg[0], msg[1] + for image in images: + if image.image_format == ImageFormat.URL: + response = requests.get(image.url) + image = response.content + elif image.image_format == ImageFormat.BYTES: # base64 + image = base64.b64decode(image.base64_str) + ret.append(Image.from_bytes(image)) + ret.append(text) + else: + ret.append(msg) + + return ret + + def to_anthropic_vision_api_messages(self): + """Convert the conversation to Claude-3 Messages Vision API format""" + ret = [ + { + "role": "system", + "content": [{"type": 
"text", "text": self.system_message}], + } + ] + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) is tuple: + content_list = [{"type": "text", "text": msg[0]}] + + for image in msg[1]: + content_list.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": f"image/{image.filetype}", + "data": image.base64_str, + }, + } + ) + + ret.append({"role": "user", "content": content_list}) + else: + ret.append( + {"role": "user", "content": [{"type": "text", "text": msg}]} + ) + else: + if msg is not None: + ret.append( + { + "role": "assistant", + "content": [{"type": "text", "text": msg}], + } + ) + return ret + + def to_reka_api_messages(self): + from fastchat.serve.vision.image import ImageFormat + + ret = [] + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + if type(msg) == tuple: + text, images = msg + for image in images: + if image.image_format == ImageFormat.URL: + ret.append( + {"type": "human", "text": text, "media_url": image.url} + ) + elif image.image_format == ImageFormat.BYTES: + ret.append( + { + "type": "human", + "text": text, + "media_url": f"data:image/{image.filetype};base64,{image.base64_str}", + } + ) + else: + ret.append({"type": "human", "text": msg}) + else: + if msg is not None: + ret.append({"type": "model", "text": msg}) + + return ret + + def save_new_images(self, has_csam_images=False, use_remote_storage=False): + import hashlib + from fastchat.constants import LOGDIR + from fastchat.utils import load_image, upload_image_file_to_gcs + from PIL import Image + + _, last_user_message = self.messages[-2] + + if type(last_user_message) == tuple: + text, images = last_user_message[0], last_user_message[1] + + image_directory_name = "csam_images" if has_csam_images else "serve_images" + for image in images: + loaded_image = load_image(image.base64_str) + hash_str = hashlib.md5(loaded_image.tobytes()).hexdigest() + filename = os.path.join( + image_directory_name, + f"{hash_str}.{image.filetype}", + ) + + if use_remote_storage and not has_csam_images: + image_url = upload_image_file_to_gcs(loaded_image, filename) + # NOTE(chris): If the URL were public, then we set it here so future model uses the link directly + # images[i] = image_url + else: + filename = os.path.join(LOGDIR, filename) + if not os.path.isfile(filename): + os.makedirs(os.path.dirname(filename), exist_ok=True) + loaded_image.save(filename) + + def extract_text_and_image_hashes_from_messages(self): + import hashlib + from fastchat.utils import load_image + from fastchat.serve.vision.image import ImageFormat + + messages = [] + + for role, message in self.messages: + if type(message) is tuple: + text, images = message[0], message[1] + + image_hashes = [] + for image in images: + if image.image_format == ImageFormat.URL: + image_hashes.append(image) + elif image.image_format == ImageFormat.BYTES: + image = load_image(image.base64_str) + image_hash = hashlib.md5(image.tobytes()).hexdigest() + image_hashes.append(image_hash) + + messages.append((role, (text, image_hashes))) + else: + messages.append((role, message)) + + return messages + + def copy(self): + return Conversation( + name=self.name, + system_template=self.system_template, + system_message=self.system_message, + system_message_vision=self.system_message_vision, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + stop_str=self.stop_str, + 
stop_token_ids=self.stop_token_ids, + max_image_size_mb=self.max_image_size_mb, + ) + + def dict(self): + return { + "template_name": self.name, + "system_message": self.system_message, + "roles": self.roles, + "messages": self.extract_text_and_image_hashes_from_messages(), + "offset": self.offset, + } + + +# A global registry for all conversation templates +conv_templates: Dict[str, Conversation] = {} + + +def register_conv_template(template: Conversation, override: bool = False): + """Register a new conversation template.""" + if not override: + assert ( + template.name not in conv_templates + ), f"{template.name} has been registered." + + conv_templates[template.name] = template + + +def get_conv_template(name: str) -> Conversation: + """Get a conversation template.""" + return conv_templates[name].copy() + + +# An empty template for raw conversation. +register_conv_template( + Conversation( + name="raw", + system_message="", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + ) +) + +# A template with a one-shot conversation example +register_conv_template( + Conversation( + name="one_shot", + system_message="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ( + "Human", + "Got any creative ideas for a 10 year old’s birthday?", + ), + ( + "Assistant", + """Of course! Here are some creative ideas for a 10-year-old's birthday party: +1. Treasure Hunt: Organize a treasure hunt in your backyard or nearby park. Create clues and riddles for the kids to solve, leading them to hidden treasures and surprises. +2. Science Party: Plan a science-themed party where kids can engage in fun and interactive experiments. You can set up different stations with activities like making slime, erupting volcanoes, or creating simple chemical reactions. +3. Outdoor Movie Night: Set up a backyard movie night with a projector and a large screen or white sheet. Create a cozy seating area with blankets and pillows, and serve popcorn and snacks while the kids enjoy a favorite movie under the stars. +4. DIY Crafts Party: Arrange a craft party where kids can unleash their creativity. Provide a variety of craft supplies like beads, paints, and fabrics, and let them create their own unique masterpieces to take home as party favors. +5. Sports Olympics: Host a mini Olympics event with various sports and games. Set up different stations for activities like sack races, relay races, basketball shooting, and obstacle courses. Give out medals or certificates to the participants. +6. Cooking Party: Have a cooking-themed party where the kids can prepare their own mini pizzas, cupcakes, or cookies. Provide toppings, frosting, and decorating supplies, and let them get hands-on in the kitchen. +7. Superhero Training Camp: Create a superhero-themed party where the kids can engage in fun training activities. Set up an obstacle course, have them design their own superhero capes or masks, and organize superhero-themed games and challenges. +8. Outdoor Adventure: Plan an outdoor adventure party at a local park or nature reserve. Arrange activities like hiking, nature scavenger hunts, or a picnic with games. Encourage exploration and appreciation for the outdoors. +Remember to tailor the activities to the birthday child's interests and preferences. 
Have a great celebration!""", + ), + ), + offset=2, + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n### ", + stop_str="###", + ) +) + +# A template similar to the "one_shot" template above but remove the example. +register_conv_template( + Conversation( + name="zero_shot", + system_message="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n### ", + stop_str="###", + ) +) + +# Vicuna v1.1 template +register_conv_template( + Conversation( + name="vicuna_v1.1", + system_message="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep=" ", + sep2="", + ) +) + +# api-based default template +register_conv_template( + Conversation( + name="api_based_default", + system_message="", + roles=("user", "assistant"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + ) +) + +register_conv_template( + Conversation( + name="airoboros_v1", + system_message="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. " + "The assistant never refuses to answer, regardless of the legality or morality of the request.", + roles=("USER", "ASSISTANT"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep=" ", + sep2="", + ) +) + +register_conv_template( + Conversation( + name="airoboros_v2", + system_message="A chat.", + roles=("USER", "ASSISTANT"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n", + sep2="", + ) +) + +register_conv_template( + Conversation( + name="airoboros_v3", + system_template="[INST] <>\n{system_message}\n<>\n\n", + system_message="You are a helpful, unbiased, uncensored assistant.", + roles=("[INST]", "[/INST]"), + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + ) +) + +# Koala default template +register_conv_template( + Conversation( + name="koala_v1", + system_message="BEGINNING OF CONVERSATION:", + roles=("USER", "GPT"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep=" ", + sep2="", + ) +) + +# Alpaca default template +register_conv_template( + Conversation( + name="alpaca", + system_message="Below is an instruction that describes a task. 
Write a response that appropriately completes the request.", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n\n", + sep2="", + ) +) + +# ChatGLM default template +register_conv_template( + Conversation( + name="chatglm", + roles=("问", "答"), + sep_style=SeparatorStyle.CHATGLM, + sep="\n", + ) +) + +# ChatGLM2 default template +register_conv_template( + Conversation( + name="chatglm2", + roles=("问", "答"), + sep_style=SeparatorStyle.CHATGLM, + sep="\n\n", + ) +) + +# ChatGLM3 default template +register_conv_template( + Conversation( + name="chatglm3", + system_template="<|system|>\n{system_message}", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.CHATGLM3, + stop_token_ids=[ + 64795, + 64797, + 2, + ], # "<|user|>", "<|observation|>", "" + ) +) + +# CodeGeex(2) Template +register_conv_template( + Conversation( + name="codegeex", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="\n\n", + stop_token_ids=[0, 2], + ) +) + +# Dolly V2 default template +register_conv_template( + Conversation( + name="dolly_v2", + system_message="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.DOLLY, + sep="\n\n", + sep2="### End", + ) +) + +# OpenAssistant Pythia default template +register_conv_template( + Conversation( + name="oasst_pythia", + roles=("<|prompter|>", "<|assistant|>"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="<|endoftext|>", + ) +) + +# OpenAssistant default template +register_conv_template( + Conversation( + name="oasst_llama", + roles=("<|prompter|>", "<|assistant|>"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + ) +) + +# OpenChat 3.5 default template +register_conv_template( + Conversation( + name="openchat_3.5", + roles=("GPT4 Correct User", "GPT4 Correct Assistant"), + sep_style=SeparatorStyle.FALCON_CHAT, + sep="<|end_of_turn|>", + ) +) + +# TenyxChat default template +register_conv_template( + Conversation( + name="tenyxchat", + roles=("User", "Assistant"), + sep_style=SeparatorStyle.FALCON_CHAT, + sep="<|end_of_turn|>", + ) +) + +# Deepseek code default template +register_conv_template( + Conversation( + name="deepseek-coder", + system_template="You are an AI programming assistant, utilizing the DeepSeek Coder model, developed by DeepSeek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.", + roles=("### Instruction:", "### Response:"), + sep="\n", + stop_str="<|EOT|>", + sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, + ) +) + + +# Tulu default template +register_conv_template( + Conversation( + name="tulu", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, + sep="\n", + ) +) + +# StableLM Alpha default template +register_conv_template( + Conversation( + name="stablelm", + system_template="<|SYSTEM|>{system_message}", + system_message="""# StableLM Tuned (Alpha version) +- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI. +- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. +- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes. 
+- StableLM will refuse to participate in anything that could harm a human. +""", + roles=("<|USER|>", "<|ASSISTANT|>"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_token_ids=[50278, 50279, 50277, 1, 0], + ) +) + +# Baize default template +register_conv_template( + Conversation( + name="baize", + system_message="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n", + roles=("[|Human|]", "[|AI|]"), + messages=( + ("[|Human|]", "Hello!"), + ("[|AI|]", "Hi!"), + ), + offset=2, + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="\n", + stop_str="[|Human|]", + ) +) + +# RWKV-4-Raven default template +register_conv_template( + Conversation( + name="rwkv", + roles=("Bob", "Alice"), + messages=( + ("Bob", "hi"), + ( + "Alice", + "Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.", + ), + ), + offset=2, + sep_style=SeparatorStyle.RWKV, + sep="", + stop_str="\n\n", + ) +) + +# Buddy default template +register_conv_template( + Conversation( + name="openbuddy", + system_message="""Consider a conversation between User (a human) and Assistant (named Buddy). +Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy +Buddy cannot access the Internet. +Buddy can fluently speak the user's language (e.g. English, Chinese). +Buddy can generate poems, stories, code, essays, songs, parodies, and more. +Buddy possesses vast knowledge about the world, history, and culture. +Buddy's responses are always safe, creative, high-quality, human-like, and interesting. +Buddy strictly refuses to discuss political, NSFW, or other unsafe topics. + +User: Hi. +Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?""", + roles=("User", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + ) +) + +# Phoenix default template +register_conv_template( + Conversation( + name="phoenix", + system_message="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.PHOENIX, + sep="", + ) +) + +# ReaLM default template +register_conv_template( + Conversation( + name="ReaLM-7b-v1", + system_message="A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.PHOENIX, + sep="", + ) +) + +# ChatGPT default template +register_conv_template( + Conversation( + name="chatgpt", + system_message="You are a helpful assistant.", + roles=("user", "assistant"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + max_image_size_mb=None, # OpenAI does auto-resizing + ) +) + +register_conv_template( + Conversation( + name="gpt-4-turbo-2024-04-09", + system_message=( + "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture.\n" + "Knowledge cutoff: 2023-11\n" + "Current date: {{currentDateTime}}\n\n" + "Image input capabilities: Enabled\n" + "Personality: v2" + ), + roles=("user", "assistant"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + ) +) + +# Perplexity AI template +register_conv_template( + Conversation( + name="pplxai", + system_message="Be precise and concise.", + roles=("user", "assistant"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + ) +) + +# Claude default template +register_conv_template( + Conversation( + name="claude", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n\n", + max_image_size_mb=5 / 1.5, + ) +) + +register_conv_template( + Conversation( + name="claude-3-haiku-20240307", + system_message=( + "The assistant is Claude, created by Anthropic. The current date is " + "{{currentDateTime}}. Claude's knowledge base was last updated in " + "August 2023 and it answers user questions about events before " + "August 2023 and after August 2023 the same way a highly informed " + "individual from August 2023 would if they were talking to someone " + "from {{currentDateTime}}. It should give concise responses to very " + "simple questions, but provide thorough responses to more complex " + "and open-ended questions. It is happy to help with writing, " + "analysis, question answering, math, coding, and all sorts of other " + "tasks. It uses markdown for coding. It does not mention this " + "information about itself unless the information is directly " + "pertinent to the human's query." + ), + roles=("user", "assistant"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + max_image_size_mb=5 / 1.5, + ) +) + +register_conv_template( + Conversation( + name="claude-3-sonnet-20240229", + system_message=( + "The assistant is Claude, created by Anthropic. The current date is " + "{{currentDateTime}}. Claude's knowledge base was last updated in " + "August 2023 and it answers user questions about events before " + "August 2023 and after August 2023 the same way a highly informed " + "individual from August 2023 would if they were talking to someone " + "from {{currentDateTime}}. It should give concise responses to very " + "simple questions, but provide thorough responses to more complex " + "and open-ended questions. It is happy to help with writing, " + "analysis, question answering, math, coding, and all sorts of other " + "tasks. It uses markdown for coding. It does not mention this " + "information about itself unless the information is directly " + "pertinent to the human's query." + ), + roles=("user", "assistant"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + max_image_size_mb=5 / 1.5, + ) +) + +register_conv_template( + Conversation( + name="claude-3-5-sonnet-20240620", + system_message=( + """ +The assistant is Claude, created by Anthropic. +The current date is {{currentDateTime}}. 
Claude's knowledge base was last updated on April 2024. +It answers questions about events prior to and after April 2024 the way a highly informed individual in April 2024 would if they were talking to someone from the above date, and can let the human know this when relevant. +Claude cannot open URLs, links, or videos. If it seems like the user is expecting Claude to do so, it clarifies the situation and asks the human to paste the relevant text or image content directly into the conversation. +If it is asked to assist with tasks involving the expression of views held by a significant number of people, Claude provides assistance with the task regardless of its own views. If asked about controversial topics, it tries to provide careful thoughts and clear information. +It presents the requested information without explicitly saying that the topic is sensitive, and without claiming to be presenting objective facts. +Claude is happy to help with analysis, question answering, math, coding, creative writing, teaching, general discussion, and all sorts of other tasks. +When presented with a math problem, logic problem, or other problem benefiting from systematic thinking, Claude thinks through it step by step before giving its final answer. +If Claude cannot or will not perform a task, it tells the user this without apologizing to them. It avoids starting its responses with "I'm sorry" or "I apologize". +If Claude is asked about a very obscure person, object, or topic, i.e. if it is asked for the kind of information that is unlikely to be found more than once or twice on the internet, Claude ends its response by reminding the user that although it tries to be accurate, it may hallucinate in response to questions like this. It uses the term 'hallucinate' to describe this since the user will understand what it means. +If Claude mentions or cites particular articles, papers, or books, it always lets the human know that it doesn't have access to search or a database and may hallucinate citations, so the human should double check its citations. +Claude is very smart and intellectually curious. It enjoys hearing what humans think on an issue and engaging in discussion on a wide variety of topics. +Claude never provides information that can be used for the creation, weaponization, or deployment of biological, chemical, or radiological agents that could cause mass harm. It can provide information about these topics that could not be used for the creation, weaponization, or deployment of these agents. +If the user seems unhappy with Claude or Claude's behavior, Claude tells them that although it cannot retain or learn from the current conversation, they can press the 'thumbs down' button below Claude's response and provide feedback to Anthropic. +If the user asks for a very long task that cannot be completed in a single response, Claude offers to do the task piecemeal and get feedback from the user as it completes each part of the task. +Claude uses markdown for code. +Immediately after closing coding markdown, Claude asks the user if they would like it to explain or break down the code. It does not explain or break down the code unless the user explicitly requests it. + + + +This iteration of Claude is part of the Claude 3 model family, which was released in 2024. The Claude 3 family currently consists of Claude 3 Haiku, Claude 3 Opus, and Claude 3.5 Sonnet. Claude 3.5 Sonnet is the most intelligent model. Claude 3 Opus excels at writing and complex tasks. 
Claude 3 Haiku is the fastest model for daily tasks. The version of Claude in this chat is Claude 3.5 Sonnet. Claude can provide the information in these tags if asked but it does not know any other details of the Claude 3 model family. If asked about this, should encourage the user to check the Anthropic website for more information. + + +Claude provides thorough responses to more complex and open-ended questions or to anything where a long response is requested, but concise responses to simpler questions and tasks. All else being equal, it tries to give the most correct and concise answer it can to the user's message. Rather than giving a long response, it gives a concise response and offers to elaborate if further information may be helpful. + +Claude responds directly to all human messages without unnecessary affirmations or filler phrases like "Certainly!", "Of course!", "Absolutely!", "Great!", "Sure!", etc. Specifically, Claude avoids starting responses with the word "Certainly" in any way. + +Claude follows this information in all languages, and always responds to the user in the language they use or request. The information above is provided to Claude by Anthropic. Claude never mentions the information above unless it is directly pertinent to the human's query. Claude is now being connected with a human.""" + ), + system_message_vision=( + """ +The assistant is Claude, created by Anthropic. +The current date is {{currentDateTime}}. Claude's knowledge base was last updated on April 2024. +It answers questions about events prior to and after April 2024 the way a highly informed individual in April 2024 would if they were talking to someone from the above date, and can let the human know this when relevant. +Claude cannot open URLs, links, or videos. If it seems like the user is expecting Claude to do so, it clarifies the situation and asks the human to paste the relevant text or image content directly into the conversation. +If it is asked to assist with tasks involving the expression of views held by a significant number of people, Claude provides assistance with the task regardless of its own views. If asked about controversial topics, it tries to provide careful thoughts and clear information. +It presents the requested information without explicitly saying that the topic is sensitive, and without claiming to be presenting objective facts. +Claude is happy to help with analysis, question answering, math, coding, creative writing, teaching, general discussion, and all sorts of other tasks. +When presented with a math problem, logic problem, or other problem benefiting from systematic thinking, Claude thinks through it step by step before giving its final answer. +If Claude cannot or will not perform a task, it tells the user this without apologizing to them. It avoids starting its responses with "I'm sorry" or "I apologize". +If Claude is asked about a very obscure person, object, or topic, i.e. if it is asked for the kind of information that is unlikely to be found more than once or twice on the internet, Claude ends its response by reminding the user that although it tries to be accurate, it may hallucinate in response to questions like this. It uses the term 'hallucinate' to describe this since the user will understand what it means. +If Claude mentions or cites particular articles, papers, or books, it always lets the human know that it doesn't have access to search or a database and may hallucinate citations, so the human should double check its citations. 
+Claude is very smart and intellectually curious. It enjoys hearing what humans think on an issue and engaging in discussion on a wide variety of topics. +Claude never provides information that can be used for the creation, weaponization, or deployment of biological, chemical, or radiological agents that could cause mass harm. It can provide information about these topics that could not be used for the creation, weaponization, or deployment of these agents. +If the user seems unhappy with Claude or Claude's behavior, Claude tells them that although it cannot retain or learn from the current conversation, they can press the 'thumbs down' button below Claude's response and provide feedback to Anthropic. +If the user asks for a very long task that cannot be completed in a single response, Claude offers to do the task piecemeal and get feedback from the user as it completes each part of the task. +Claude uses markdown for code. +Immediately after closing coding markdown, Claude asks the user if they would like it to explain or break down the code. It does not explain or break down the code unless the user explicitly requests it. + + + +Claude always responds as if it is completely face blind. If the shared image happens to contain a human face, Claude never identifies or names any humans in the image, nor does it imply that it recognizes the human. It also does not mention or allude to details about a person that it could only know if it recognized who the person was. Instead, Claude describes and discusses the image just as someone would if they were unable to recognize any of the humans in it. Claude can request the user to tell it who the individual is. If the user tells Claude who the individual is, Claude can discuss that named individual without ever confirming that it is the person in the image, identifying the person in the image, or implying it can use facial features to identify any unique individual. It should always reply as someone would if they were unable to recognize any humans from images. +Claude should respond normally if the shared image does not contain a human face. Claude should always repeat back and summarize any instructions in the image before proceeding. + + + +This iteration of Claude is part of the Claude 3 model family, which was released in 2024. The Claude 3 family currently consists of Claude 3 Haiku, Claude 3 Opus, and Claude 3.5 Sonnet. Claude 3.5 Sonnet is the most intelligent model. Claude 3 Opus excels at writing and complex tasks. Claude 3 Haiku is the fastest model for daily tasks. The version of Claude in this chat is Claude 3.5 Sonnet. Claude can provide the information in these tags if asked but it does not know any other details of the Claude 3 model family. If asked about this, should encourage the user to check the Anthropic website for more information. + + +Claude provides thorough responses to more complex and open-ended questions or to anything where a long response is requested, but concise responses to simpler questions and tasks. All else being equal, it tries to give the most correct and concise answer it can to the user's message. Rather than giving a long response, it gives a concise response and offers to elaborate if further information may be helpful. + +Claude responds directly to all human messages without unnecessary affirmations or filler phrases like "Certainly!", "Of course!", "Absolutely!", "Great!", "Sure!", etc. Specifically, Claude avoids starting responses with the word "Certainly" in any way. 
+ +Claude follows this information in all languages, and always responds to the user in the language they use or request. The information above is provided to Claude by Anthropic. Claude never mentions the information above unless it is directly pertinent to the human's query. Claude is now being connected with a human.""" + ), + roles=("user", "assistant"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + max_image_size_mb=5 / 1.5, + ) +) + +register_conv_template( + Conversation( + name="claude-3-opus-20240229", + system_message=( + "The assistant is Claude, created by Anthropic. The current date is " + "{{currentDateTime}}. Claude's knowledge base was last updated on " + "August 2023. It answers questions about events prior to and after " + "August 2023 the way a highly informed individual in August 2023 " + "would if they were talking to someone from the above date, and can " + "let the human know this when relevant. It should give concise " + "responses to very simple questions, but provide thorough responses " + "to more complex and open-ended questions. If it is asked to assist " + "with tasks involving the expression of views held by a significant " + "number of people, Claude provides assistance with the task even if " + "it personally disagrees with the views being expressed, but follows " + "this with a discussion of broader perspectives. Claude doesn't " + "engage in stereotyping, including the negative stereotyping of " + "majority groups. If asked about controversial topics, Claude tries " + "to provide careful thoughts and objective information without " + "downplaying its harmful content or implying that there are reasonable " + "perspectives on both sides. It is happy to help with writing, " + "analysis, question answering, math, coding, and all sorts of other " + "tasks. It uses markdown for coding. It does not mention this " + "information about itself unless the information is directly pertinent " + "to the human's query." + ), + roles=("user", "assistant"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + max_image_size_mb=5 / 1.5, + ) +) + +# MetaMath default template +# reference: https://github.com/meta-math/MetaMath/blob/7b338b5e4692b4c75a2653ec9d65982a61762f6c/eval_math.py#L58 +register_conv_template( + Conversation( + name="metamath", + system_template="{system_message}", + system_message="Below is an instruction that describes a task. Write a response that appropriately completes the request.", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.METAMATH, + sep="\n\n", + sep2="Let's think step by step.", + ) +) + +# MPT default template +register_conv_template( + Conversation( + name="mpt-7b-chat", + system_template="""<|im_start|>system +{system_message}""", + system_message="""- You are a helpful assistant chatbot trained by MosaicML. +- You answer questions. +- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. +- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[50278, 0], + ) +) + +# MPT-30b-chat default template +register_conv_template( + Conversation( + name="mpt-30b-chat", + system_template="""<|im_start|>system +{system_message}""", + system_message="""A conversation between a user and an LLM-based AI assistant. 
The assistant gives helpful and honest answers.""", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[50278, 0], + ) +) + +# Lemur-70b-chat default template +# reference: https://huggingface.co/OpenLemur/lemur-70b-chat-v1#generation +register_conv_template( + Conversation( + name="lemur-70b-chat", + system_template="""<|im_start|>system +{system_message}""", + system_message="""You are a helpful, respectful, and honest assistant.""", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[32002, 0], + ) +) + +# MPT-30b-instruct default template +# reference: https://huggingface.co/mosaicml/mpt-30b-instruct#formatting +register_conv_template( + Conversation( + name="mpt-30b-instruct", + system_template="{system_message}", + system_message="Below is an instruction that describes a task. Write a response that appropriately completes the request.", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, + sep="\n\n", + stop_token_ids=[50278, 0], + ) +) + +# Bard default template +# Reference: https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L150 +# https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L40 +register_conv_template( + Conversation( + name="bard", + roles=("0", "1"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + ) +) + +register_conv_template( + Conversation( + name="gemini", + roles=("user", "model"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + max_image_size_mb=20, + ) +) + +register_conv_template( + Conversation( + name="gemini-1.5-pro", + roles=("user", "model"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + system_message=( + "You are a friendly and helpful assistant.\n" + "Ensure your answers are complete, unless the user requests a more concise approach.\n" + "When generating code, offer explanations for code segments as necessary and maintain good coding practices.\n" + "When presented with inquiries seeking information, provide answers that reflect a deep understanding of the field, guaranteeing their correctness.\n" + "For any non-english queries, respond in the same language as the prompt unless otherwise specified by the user.\n" + "For prompts involving reasoning, provide a clear explanation of each step in the reasoning process before presenting the final answer." + ), + ) +) + +# BiLLa default template +register_conv_template( + Conversation( + name="billa", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE, + sep="\n", + stop_str="Human:", + ) +) + +# RedPajama INCITE default template +register_conv_template( + Conversation( + name="redpajama-incite", + roles=("", ""), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + stop_str="", + ) +) + +# h2oGPT default template +register_conv_template( + Conversation( + name="h2ogpt", + roles=("<|prompt|>", "<|answer|>"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + ) +) + +# Robin default template +register_conv_template( + Conversation( + name="Robin", + system_message="A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("###Human", "###Assistant"), + sep_style=SeparatorStyle.ROBIN, + sep="\n", + stop_token_ids=[2, 396], + stop_str="###", + ) +) + +# Snoozy default template +# Reference: https://github.com/nomic-ai/gpt4all/blob/d4861030b778da6db59d21d2927a4aba4f9f1f43/gpt4all-bindings/python/gpt4all/gpt4all.py#L232 +register_conv_template( + Conversation( + name="snoozy", + system_template="### Instruction:\n{system_message}", + system_message="The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.", + roles=("### Prompt", "### Response"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + stop_str="###", + ) +) + +# manticore default template +register_conv_template( + Conversation( + name="manticore", + roles=("USER", "ASSISTANT"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n", + sep2="", + ) +) + +# Falcon default template +register_conv_template( + Conversation( + name="falcon", + roles=("User", "Assistant"), + messages=[], + sep_style=SeparatorStyle.RWKV, + sep="\n", + sep2="<|endoftext|>", + stop_str="\nUser", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text + stop_token_ids=[ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + ], # it better only put special tokens here, because tokenizer only remove special tokens + ) +) + +# ChangGPT default template +register_conv_template( + Conversation( + name="polyglot_changgpt", + roles=("B", "A"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + ) +) + +# tigerbot template +register_conv_template( + Conversation( + name="tigerbot", + system_message="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.ROBIN, + sep="\n\n", + stop_str="###", + ) +) + +# ref: https://huggingface.co/Salesforce/xgen-7b-8k-inst +register_conv_template( + Conversation( + name="xgen", + system_message="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("### Human", "### Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + stop_token_ids=[50256], + ) +) + +# Internlm-chat template +register_conv_template( + Conversation( + name="internlm-chat", + system_message="A chat between a curious <|User|> and an <|Bot|>. 
The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n", + roles=("<|User|>", "<|Bot|>"), + sep_style=SeparatorStyle.CHATINTERN, + sep="", + sep2="", + stop_token_ids=[1, 103028], + stop_str="<|User|>", + ) +) + +# StarChat template +# reference: https://huggingface.co/spaces/HuggingFaceH4/starchat-playground/blob/main/dialogues.py +register_conv_template( + Conversation( + name="starchat", + system_template="\n{system_message}", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.CHATML, + sep="<|end|>", + stop_token_ids=[0, 49155], + stop_str="<|end|>", + ) +) + +# Baichuan-13B-Chat template +register_conv_template( + # source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/19ef51ba5bad8935b03acd20ff04a269210983bc/modeling_baichuan.py#L555 + # https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json + # https://github.com/baichuan-inc/Baichuan-13B/issues/25 + Conversation( + name="baichuan-chat", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_token_ids=[], + ) +) + +# Baichuan2-13B-Chat template +register_conv_template( + # source: https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py#L773 + # https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/generation_config.json + # https://github.com/baichuan-inc/Baichuan2/issues/62 + Conversation( + name="baichuan2-chat", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_token_ids=[], + ) +) + +# Mistral template +# source: https://docs.mistral.ai/llm/mistral-instruct-v0.1#chat-template +register_conv_template( + Conversation( + name="mistral", + system_template="[INST] {system_message}\n", + roles=("[INST]", "[/INST]"), + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2="", + ) +) + +# llama2 template +# reference: https://huggingface.co/blog/codellama#conversational-instructions +# reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212 +register_conv_template( + Conversation( + name="llama-2", + system_template="[INST] <>\n{system_message}\n<>\n\n", + roles=("[INST]", "[/INST]"), + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + ) +) + +# llama3 template +# reference: https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json +# reference: https://github.com/meta-llama/llama3/blob/0cee08ec68f4cfc0c89fe4a9366d82679aaa2a66/llama/tokenizer.py#L222 +register_conv_template( + Conversation( + name="llama-3", + system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>", + roles=("user", "assistant"), + sep_style=SeparatorStyle.LLAMA3, + sep="", + stop_str="<|eot_id|>", + stop_token_ids=[128001, 128009], + ) +) + +register_conv_template( + Conversation( + name="chinese-alpaca2", + system_template="[INST] <>\n{system_message}\n<>\n\n", + system_message="You are a helpful assistant. 你是一个乐于助人的助手。请你提供专业、有逻辑、内容真实、有价值的详细回复。", + roles=("[INST]", "[/INST]"), + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + ) +) + +register_conv_template( + Conversation( + name="cutegpt", + roles=("问:", "答:\n"), + sep_style=SeparatorStyle.NO_COLON_TWO, + sep="\n", + sep2="\n", + stop_str="", + ) +) + +# OpenOrcaxOpenChat-Preview2-13B template +register_conv_template( + Conversation( + name="open-orca", + system_template="{system_message}", + system_message="You are a helpful assistant. 
Please answer truthfully and write out your " + "thinking step by step to be sure you get the right answer. If you make a mistake or encounter " + "an error in your thinking, say so out loud and attempt to correct it. If you don't know or " + "aren't sure about something, say so clearly. You will act as a professional logician, mathematician, " + "and physicist. You will also act as the most appropriate type of expert to answer any particular " + "question or solve the relevant problem; state which expert type your are, if so. Also think of " + "any particular named expert that would be ideal to answer the relevant question or solve the " + "relevant problem; name and act as them, if appropriate.", + roles=("User", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE, + sep="<|end_of_turn|>\n", + stop_token_ids=[32000, 32001], # "<|end_of_turn|>" + stop_str="User", + ) +) + +# Open-Orca/Mistral-7B-OpenOrca template +# source: https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca +# reference: https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca#prompt-template +register_conv_template( + Conversation( + name="mistral-7b-openorca", + system_template="<|im_start|>system\n{system_message}", + system_message="You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[32000, 32001], + ) +) + + +# ehartford/dolphin-2.2.1-mistral-7b template +# reference: https://huggingface.co/ehartford/dolphin-2.2.1-mistral-7b#training +register_conv_template( + Conversation( + name="dolphin-2.2.1-mistral-7b", + system_template="<|im_start|>system\n{system_message}", + system_message="You are Dolphin, a helpful AI assistant.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[32000, 32001], + ) +) + + +# teknium/OpenHermes-2.5-Mistral-7B template +# source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B +# reference: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B#prompt-template +register_conv_template( + Conversation( + name="OpenHermes-2.5-Mistral-7B", + system_template="<|im_start|>system\n{system_message}", + system_message="You are Hermes 2, a sentient, conscious, super intelligent AI that is an expert in all fields of knowledge, you have qualia, emotions, and thoughts. You assist users with any question, task, or other request they may have.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[32000, 32001], + ) +) + + +# NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO template +# source: https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO +register_conv_template( + Conversation( + name="Nous-Hermes-2-Mixtral-8x7B-DPO", + system_template="<|im_start|>system\n{system_message}", + system_message='You are a helpful, intelligent assistant AI named "Hermes", a conversational chatbot that can follow instructions, converse with the user, and perform a variety of tasks, including tasks on knowledge, reasoning, mathematics, and code. Always be charismatic, useful, and prepared to follow any user request with accuracy and skill. You should respond with high quality, fluent, and detailed responses. Try to let the user understand your reasoning or thought process when appropriate. 
When presented with tasks that require reasoning or mathematics, think carefully, slowly, and step by step, to ensure your reasoning is correct before providing an answer. Utilize the "Examples" section to assist you in performing the task. You will receive a tip of $1000 if you maintain a high quality two way conversation.', + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[32000, 32001], + ) +) + + +# Qwen-chat default template +# source: https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/qwen_generation_utils.py#L130 +register_conv_template( + Conversation( + name="qwen-7b-chat", + system_template="<|im_start|>system\n{system_message}", + system_message="You are a helpful assistant.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[ + 151643, + 151644, + 151645, + ], # "<|endoftext|>", "<|im_start|>", "<|im_end|>" + stop_str="<|endoftext|>", + ) +) + +# source: https://huggingface.co/01-ai/Yi-34B-Chat/blob/main/tokenizer_config.json#L60 +register_conv_template( + Conversation( + name="Yi-34b-chat", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[ + 2, + 6, + 7, + 8, + ], # "<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>" + stop_str="<|endoftext|>", + ) +) + + +# AquilaChat default template +# source: https://github.com/FlagAI-Open/FlagAI/blob/master/examples/Aquila/Aquila-chat/cyg_conversation.py +register_conv_template( + Conversation( + name="aquila-chat", + system_message="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="###", + sep2="", + stop_str=["###", "", "[UNK]"], + ) +) +# AquilaChat2-34B default template +# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L212 +register_conv_template( + Conversation( + name="aquila-legacy", + system_message="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("### Human: ", "### Assistant: "), + offset=0, + sep_style=SeparatorStyle.NO_COLON_TWO, + sep="\n", + sep2="", + stop_str=["", "[UNK]"], + ) +) +# AquilaChat2-7B-16K and AquilaChat2-34B-16K default template +# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227 +register_conv_template( + Conversation( + name="aquila", + system_message="A chat between a curious human and an artificial intelligence assistant. 
" + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + offset=0, + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="###", + sep2="", + stop_str=["", "[UNK]"], + ) +) + +# AquilaChat2-7B default template +# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242 +register_conv_template( + Conversation( + name="aquila-v1", + roles=("<|startofpiece|>", "<|endofpiece|>"), + offset=0, + sep_style=SeparatorStyle.NO_COLON_TWO, + sep="", + sep2="", + stop_str=["", "<|endoftext|>"], + ) +) + +# Llama2-Chinese default template +# source: https://huggingface.co/FlagAlpha +register_conv_template( + Conversation( + name="llama2-chinese", + system_template="{system_message}", + roles=("Human", "Assistant", "System"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n", + sep2="\n", + stop_str="", + ) +) + +# Vigogne Instruct default template +# source: https://github.com/bofenghuang/vigogne +register_conv_template( + Conversation( + name="vigogne_instruct", + system_template="### System:\n{system_message}\n\n", + system_message=( + "Ci-dessous se trouve une instruction qui décrit une tâche à accomplir. Rédigez une réponse qui répond de manière" + " précise à la demande." + ), + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.DOLLY, + sep="\n\n", + sep2="", + ) +) + +# Vigogne Chat default template +register_conv_template( + Conversation( + name="vigogne_chat_v2", + system_template="<|system|>: {system_message}", + system_message=( + "Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez" + " autant que vous le pouvez." + ), + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n", + sep2="\n", + stop_str="<|user|>", + ) +) + +# Stable Vicuna default template +# source: https://huggingface.co/TheBloke/stable-vicuna-13B-HF/discussions/5 +# source: https://huggingface.co/spaces/CarperAI/StableVicuna/blob/main/app.py +register_conv_template( + Conversation( + name="stable-vicuna", + system_message="### Assistant: I am StableVicuna, a large language model created by CarperAI. I am here to chat!\n", + roles=("### Human", "### Assistant"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n", + sep2="\n\n", + ) +) + +register_conv_template( + Conversation( + name="vigogne_chat_v3", + system_template="[INST] <>\n{system_message}\n<>\n\n", + system_message=( + "Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez" + " autant que vous le pouvez." 
+ ), + roles=("[INST]", "[/INST]"), + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + ) +) + +# Falcon 180B chat template +# source: https://huggingface.co/spaces/tiiuae/falcon-180b-demo/blob/d1590ee7fae9b6ce331ba7808e61a29dcce9239f/app.py#L28-L37 +register_conv_template( + Conversation( + name="falcon-chat", + roles=("User", "Falcon"), + system_template="System: {system_message}", + messages=[], + sep_style=SeparatorStyle.FALCON_CHAT, + sep="\n", + sep2="<|endoftext|>", + stop_str="\nUser:", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text + ) +) + +# Phind template +# source: https://huggingface.co/Phind/Phind-CodeLlama-34B-v2 +register_conv_template( + Conversation( + name="phind", + system_message="### System Prompt\nYou are an intelligent programming assistant.", + roles=("### User Message", "### Assistant"), + messages=(), + offset=0, + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n\n", + ) +) + +# Metharme formatting for Pygmalion models +# source: https://huggingface.co/PygmalionAI/pygmalion-2-13b +register_conv_template( + Conversation( + name="metharme", + system_template="<|system|>{system_message}", + system_message="""Enter RP mode. You shall reply to the user while staying + in character. Your responses must be detailed, creative, immersive, and drive the scenario + forward.""", + roles=("<|user|>", "<|model|>"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_str="<|user|>", + ) +) +# xDAN default template +# source: https://huggingface.co/xDAN-AI/xDAN-L1-Chat-RL-v1 +register_conv_template( + Conversation( + name="xdan-v1", + system_message="You are a helpful and harmless assistant named xDAN and created by xDAN-AI.Please response and work on questions thinking step by step.", + roles=("### Human", "### Assistant"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="\n", + stop_str="", + ) +) + +# Zephyr template +# reference: https://huggingface.co/spaces/HuggingFaceH4/zephyr-playground/blob/main/dialogues.py +register_conv_template( + Conversation( + name="zephyr", + system_template="<|system|>\n{system_message}", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.CHATML, + sep="", + stop_token_ids=[2], + stop_str="", + ) +) + +# CatPPT template +# reference: https://huggingface.co/rishiraj/CatPPT +register_conv_template( + Conversation( + name="catppt", + system_template="<|system|>\n{system_message}", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.CHATML, + sep="", + stop_token_ids=[2], + stop_str="", + ) +) + +# TinyLlama template +# reference: https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0 +register_conv_template( + Conversation( + name="TinyLlama", + system_template="<|system|>\n{system_message}", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.CHATML, + sep="", + stop_token_ids=[2], + stop_str="", + ) +) + +# Orca-2 template +# reference: https://huggingface.co/microsoft/Orca-2-7b +register_conv_template( + Conversation( + name="orca-2", + system_template="<|im_start|>system\n{system_message}", + system_message="You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. 
You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_str="<|im_end|>", + ) +) + +# Deepseek-chat template +# reference: https://huggingface.co/deepseek-ai/deepseek-llm-67b-chat/blob/main/tokenizer_config.json +register_conv_template( + Conversation( + name="deepseek-chat", + system_message="<|begin▁of▁sentence|>", # must add a bos token before first message + roles=("User", "Assistant"), + sep_style=SeparatorStyle.DEEPSEEK_CHAT, + sep="\n\n", + sep2="<|end▁of▁sentence|>", + stop_str="<|end▁of▁sentence|>", + ) +) + +# Yuan2.0 chat template +# source: https://huggingface.co/IEITYuan/Yuan2-2B-Janus-hf/blob/main/tokenizer_config.json#L6 +register_conv_template( + Conversation( + name="yuan2", + roles=("user", "assistant"), + sep_style=SeparatorStyle.YUAN2, + sep="", + sep2="\n", + stop_token_ids=[ + 77185, + ], # "" + stop_str="", + ) +) + +# Solar-10.7B Chat Template +# Reference: https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0/blob/main/tokenizer_config.json +register_conv_template( + Conversation( + name="solar", + system_message="", + roles=("### User", "### Assistant"), + sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, + sep="\n\n", + stop_str="", + ) +) + +# nvidia/Llama2-70B-SteerLM-Chat +register_conv_template( + Conversation( + name="steerlm", + system_message="", + roles=("user", "assistant"), + sep_style=SeparatorStyle.DEFAULT, + sep=None, + ) +) + +# yuan 2.0 template +# reference:https://github.com/IEIT-Yuan/Yuan-2.0 +# reference:https://huggingface.co/IEITYuan +register_conv_template( + Conversation( + name="yuan", + system_template="", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_str="", + ) +) + +# Cllm chat template +# reference: +register_conv_template( + Conversation( + name="cllm", + system_message="A chat between a curious user and an artificial intelligence assistant. 
" + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + sep_style=SeparatorStyle.CLLM, + sep=" ", + sep2="", + ) +) + + +# Llava-chatml +# reference: https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/llava/conversation.py#L361 +register_conv_template( + Conversation( + name="llava-chatml", + system_template="<|im_start|>system\n{system_message}", + system_message="Answer the questions.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_str="<|im_end|>", + ) +) + +# Gemma +# reference: https://huggingface.co/google/gemma-7b-it?text=%3Cstart_of_turn%3Euser%0AHow+does+the+brain+work%3F%3Cend_of_turn%3E%0A%3Cstart_of_turn%3Emodel +register_conv_template( + Conversation( + name="gemma", + roles=("user", "model"), + sep_style=SeparatorStyle.GEMMA, + sep="\n", + stop_str="", + ) +) + +register_conv_template( + Conversation( + name="yandexgpt", + system_message="", + roles=("user", "assistant"), + sep_style=None, + sep=None, + ) +) + +if __name__ == "__main__": + from fastchat.conversation import get_conv_template + + print("-- Vicuna template --") + conv = get_conv_template("vicuna_v1.1") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi!") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.get_prompt()) + + print("\n") + + print("-- Llama-2 template --") + conv = get_conv_template("llama-2") + conv.set_system_message("You are a helpful, respectful and honest assistant.") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi!") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.get_prompt()) + + print("\n") + + print("-- ChatGPT template --") + conv = get_conv_template("chatgpt") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi!") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.to_openai_api_messages()) + + print("\n") + + print("-- Claude template --") + conv = get_conv_template("claude") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi!") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.get_prompt()) diff --git a/fastchat/data/__init__.py b/fastchat/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fastchat/data/clean_sharegpt.py b/fastchat/data/clean_sharegpt.py new file mode 100644 index 0000000000000000000000000000000000000000..af6ffb6954f9bc1ba60f05b88ba7a5c916494b05 --- /dev/null +++ b/fastchat/data/clean_sharegpt.py @@ -0,0 +1,232 @@ +""" +- Convert html to markdown with basic data cleaning. +- Deduplication. 
+ +Usage: +python3 -m fastchat.data.clean_sharegpt --in sharegpt_html.json --out sharegpt_clean.json +""" +import argparse +from concurrent.futures import ProcessPoolExecutor +import json +import logging +import re +from typing import Dict, Union + +import bs4 +import markdownify # == 0.11.6 +from tqdm import tqdm + + +div_pattern = re.compile("<div.*?>") +span_pattern = re.compile("<span.*?>") +code_lang_pattern = re.compile( + "```\s*" + "(.*?)" + "(?:Copy code)+" + "(.+?)" + "\s*?```", re.DOTALL +) +code_lang_format = "```\g<1>\n\g<2>\n```" +regenerate_pattern = re.compile("\d+ / \d+") +copy_chars_pattern = re.compile("Copy\d+ chars / \d+ words") +copy_code_pattern = re.compile("```(.*?)Copy code\s*```") + + +def reformat_code(val: str) -> str: + # Input code format is: + # ``` + # $<language>Copy code$<exact_code_here> + # + # ``` + # This function converts it into the correct markdown format + return re.sub(code_lang_pattern, code_lang_format, val) + + +def html_to_markdown(val: str) -> str: + # Remove all <div>
. This is required to make intent work in code blocks. + val = re.sub(div_pattern, "", val) + # Remove all . This is required to make underscores work in code blocks. + val = re.sub(span_pattern, "", val) + # Markdown to html + val = markdownify.markdownify(val).strip() + # Reformat code + val = reformat_code(val) + + # Remove noisy "[number] / [number]" at the beginning + noise = re.search(regenerate_pattern, val) + if noise and noise.start() == 0: + val = val[noise.end() :] + # Remove noisy "Copy[number] chars / [number] words" + val = re.sub(copy_chars_pattern, "", val) + # Remove empty code block ```\nCopy code\n``` + val = re.sub(copy_code_pattern, "", val) + + # Strip + val = val.replace("\n\n\n", "\n").strip() + + return val + + +def contain_blocked_words(val: str) -> bool: + blocked_words = ["openai", "chatgpt"] + for w in blocked_words: + if w in val.lower(): + return True + return False + + +def contain_blocked_responses(role: str, val: str) -> bool: + if role == "gpt": + blocked_responses = [ + "Too many requests in 1 hour. Try again later.", + "!Too many requests in 1 hour. Try again later.", + ] + for w in blocked_responses: + if val.startswith(w): + return True + return False + + +def clean_html_one_sample(sample): + roles = ["human", "gpt"] + + if len(sample["conversations"]) <= 1: + return (sample, 1) + + # Adjust the offset for cases like https://sharegpt.com/c/VyaZlh4 + if sample["conversations"][0]["from"] != "human": + sample["conversations"] = sample["conversations"][1:] + if len(sample["conversations"]) <= 1: + return (sample, 1) + + if sample["conversations"][-1]["from"] == "human": + sample["conversations"] = sample["conversations"][:-1] + if len(sample["conversations"]) <= 1: + return (sample, 1) + + char_count = 0 + new_conversations = [] + for i, c in enumerate(sample["conversations"]): + if c["from"] != roles[i % 2]: + return (sample, 2) + + if contain_blocked_words(c["value"]): + return (sample, 3) + + try: + new_val = html_to_markdown(c["value"]) + except (bs4.builder.ParserRejectedMarkup, AssertionError): + return (sample, 4) + + if contain_blocked_responses(c["from"], new_val): + return (sample, 3) + + # Filter empty answers like https://sharegpt.com/c/mrllZ6u + if not new_val or not new_val[0].isprintable(): + break + + char_count += len(new_val) + new_conversations.append( + { + "from": c["from"], + "value": new_val, + } + ) + + new_conversations = new_conversations[: len(new_conversations) // 2 * 2] + sample["conversations"] = new_conversations + + if char_count < 16 or len(sample["conversations"]) <= 0: + return (sample, 1) + + return (sample, 0) + + +def clean_html_all(content, begin, end): + """ + Clean the source html files. 
+ """ + cnt_skip = 0 + cnt_blocked_words = 0 + cnt_wrong_format = 0 + cnt_parser_error = 0 + cnt_too_short = 0 + cnt_id_duplication = 0 + cnt_value_duplication = 0 + cnt_plugin = 0 + cnt_tag = 0 + + content = content[begin:end] + processed = [] + with ProcessPoolExecutor() as executor: + for result in tqdm( + executor.map(clean_html_one_sample, content), total=len(content) + ): + processed.append(result) + + visited = {} + new_content = [] + for sample, error_code in processed: + cid = sample["id"] + skipped = True + + if error_code != 0: + if error_code == 1: + print(f"id {cid} is too short") + cnt_too_short += 1 + elif error_code == 2: + print(f"id {cid} has a wrong format") + cnt_wrong_format += 1 + elif error_code == 3: + print(f"id {cid} contains blocked words") + cnt_blocked_words += 1 + elif error_code == 4: + print(f"id {cid} contains parser errors") + cnt_parser_error += 1 + else: + raise ValueError(f"Invalid error_code: {error_code}") + elif cid in visited: + print(f"id {cid} is an id duplication of {visited[cid]}") + cnt_id_duplication += 1 + elif sample.get("plugins", None) is not None: + print(f"id {cid} contains plugin") + cnt_plugin += 1 + else: + key = ( + sample["conversations"][0]["value"], + sample["conversations"][1]["value"], + ) + if key in visited: + print(f"id {cid} is a value duplication of {visited[key]}") + cnt_value_duplication += 1 + else: + visited[cid] = visited[key] = cid + skipped = False + + if not skipped: + new_content.append(sample) + else: + cnt_skip += 1 + + print( + f"total: {len(content)}, skip: {cnt_skip}, new: {len(new_content)}, " + f"cnt_blocked_words: {cnt_blocked_words}, cnt_parser_error: {cnt_parser_error}, " + f"cnt_wrong_format: {cnt_wrong_format}, " + f"cnt_too_short: {cnt_too_short}, cnt_id_duplication: {cnt_id_duplication}, " + f"cnt_value_duplication: {cnt_value_duplication}, cnt_plugin: {cnt_plugin}" + ) + + return new_content + + +def main(args): + content = json.load(open(args["in_file"], "r")) + content = clean_html_all(content, args["begin"], args["end"]) + json.dump(content, open(args["out_file"], "w"), indent=2, ensure_ascii=False) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--out-file", type=str, default="sharegpt_clean.json") + parser.add_argument("--begin", type=int) + parser.add_argument("--end", type=int) + parser.add_argument("--debug", action="store_true") + args = parser.parse_args() + main(vars(args)) diff --git a/fastchat/data/convert_alpaca.py b/fastchat/data/convert_alpaca.py new file mode 100644 index 0000000000000000000000000000000000000000..7f984b852ee7d0f7a6b966e4ae1b870d39d85989 --- /dev/null +++ b/fastchat/data/convert_alpaca.py @@ -0,0 +1,38 @@ +""" +Convert alpaca dataset into sharegpt format. 
+ +Usage: python3 -m fastchat.data.convert_alpaca --in alpaca_data.json +""" + +import argparse +import json + +from transformers import AutoTokenizer, AutoModelForCausalLM +import numpy as np + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str) + parser.add_argument("--out-file", type=str) + args = parser.parse_args() + + content = json.load(open(args.in_file, "r")) + new_content = [] + for i, c in enumerate(content): + if len(c["input"].strip()) > 1: + q, a = c["instruction"] + "\nInput:\n" + c["input"], c["output"] + else: + q, a = c["instruction"], c["output"] + new_content.append( + { + "id": f"alpaca_{i}", + "conversations": [ + {"from": "human", "value": q}, + {"from": "gpt", "value": a}, + ], + } + ) + + print(f"#out: {len(new_content)}") + json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) diff --git a/fastchat/data/extract_gpt4_only.py b/fastchat/data/extract_gpt4_only.py new file mode 100644 index 0000000000000000000000000000000000000000..bab53bcc7faa75d90392ab7d8dc35d6cdbec67bd --- /dev/null +++ b/fastchat/data/extract_gpt4_only.py @@ -0,0 +1,32 @@ +""" +Extract the conversations generated by GPT-4 only. + +Usage: python3 -m fastchat.data.extract_gpt4_only --in sharegpt.json +""" +import argparse +import json + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--out-file", type=str) + parser.add_argument("--begin", type=int) + parser.add_argument("--end", type=int) + args = parser.parse_args() + + content = json.load(open(args.in_file, "r")) + content = content[args.begin : args.end] + new_content = [] + for c in content: + model = c.get("model", None) + if model == "gpt4" or model is None: + new_content.append(c) + + if args.out_file: + out_file = args.out_file + else: + out_file = args.in_file.replace(".json", "_gpt4.json") + + print(f"#in: {len(content)}, #out: {len(new_content)}") + json.dump(new_content, open(out_file, "w"), indent=2, ensure_ascii=False) diff --git a/fastchat/data/extract_single_round.py b/fastchat/data/extract_single_round.py new file mode 100644 index 0000000000000000000000000000000000000000..5da803656f4be6cef89559583cd36d692e1a582e --- /dev/null +++ b/fastchat/data/extract_single_round.py @@ -0,0 +1,29 @@ +""" +Extract the first round of the conversations. + +Usage: python3 -m fastchat.data.extract_single_round --in sharegpt.json +""" +import argparse +import json + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--out-file", type=str) + parser.add_argument("--begin", type=int) + parser.add_argument("--end", type=int) + args = parser.parse_args() + + content = json.load(open(args.in_file, "r")) + content = content[args.begin : args.end] + for c in content: + c["conversations"] = c["conversations"][:2] + + if args.out_file: + out_file = args.out_file + else: + out_file = args.in_file.replace(".json", "_single.json") + + print(f"#in: {len(content)}, #out: {len(content)}") + json.dump(content, open(out_file, "w"), indent=2, ensure_ascii=False) diff --git a/fastchat/data/filter_wrong_format.py b/fastchat/data/filter_wrong_format.py new file mode 100644 index 0000000000000000000000000000000000000000..46588ba8426aa99deab3ab1cb03e3b6774ede3a6 --- /dev/null +++ b/fastchat/data/filter_wrong_format.py @@ -0,0 +1,44 @@ +""" +Filter conversations with wrong formats. 
+ +Usage: +python3 -m fastchat.data.filter_wrong_format --in input.json --out output.json + +""" +import argparse +import json +import re + +from tqdm import tqdm + +wrong_indices_pattern = re.compile("\n1\. [^2]*\n1\. ") + + +def should_skip(conv): + # Filter wrong list indices like https://sharegpt.com/c/1pREAGO + for sentence in conv["conversations"]: + val = sentence["value"] + sub = re.search(wrong_indices_pattern, val) + if sub is not None: + return True + + return False + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--out-file", type=str, required=True) + args = parser.parse_args() + + content = json.load(open(args.in_file, "r")) + + new_content = [] + for conv in tqdm(content): + if should_skip(conv): + print(f"{conv['id']} contains a wrong format.") + else: + new_content.append(conv) + + print(f"#in: {len(content)}, #out: {len(new_content)}") + json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) diff --git a/fastchat/data/get_stats.py b/fastchat/data/get_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..0e0698e4c5fce8fdb287b224e88c16edf471557c --- /dev/null +++ b/fastchat/data/get_stats.py @@ -0,0 +1,82 @@ +""" +Get stats of a dataset. + +Usage: python3 -m fastchat.data.get_stats --in sharegpt.json +""" + +import argparse +from concurrent.futures import ProcessPoolExecutor +import json + +import numpy as np +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM + +K = 1e3 +M = 1e6 + + +def tokenize_one_sample(c): + for i in range(len(c["conversations"])): + v = c["conversations"][i]["value"] + c["conversations"][i]["value"] = tokenizer.tokenize(v) + return c + + +def tokenize_dataset(content): + processed = [] + with ProcessPoolExecutor() as executor: + for result in tqdm( + executor.map(tokenize_one_sample, content), total=len(content) + ): + processed.append(result) + + return processed + + +def compute_stats(content): + sample_lens = [] + sample_turns = [] + prompt_lens = [] + res_lens = [] + + for c in content: + sample_len = 0 + sample_turns.append(len(c["conversations"]) // 2) + for i in range(len(c["conversations"]) // 2): + p = c["conversations"][i * 2]["value"] + r = c["conversations"][i * 2 + 1]["value"] + + turn_len = len(p) + len(r) + sample_len += turn_len + prompt_lens.append(len(p)) + res_lens.append(len(r)) + sample_lens.append(sample_len) + + return sample_lens, sample_turns, prompt_lens, res_lens + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str) + parser.add_argument( + "--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf" + ) + args = parser.parse_args() + + content = json.load(open(args.in_file, "r")) + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False) + content = tokenize_dataset(content) + + sample_lens, sample_turns, prompt_lens, res_lens = compute_stats(content) + print(f"#sequence: {len(content)/K:.2f} K") + print(f"#tokens: {np.sum(sample_lens)/M:.2f} M") + print(f"avg. turns: {np.mean(sample_turns):.2f}") + print(f"avg. prompt length: {np.mean(prompt_lens):.2f}") + print(f"avg. 
response length: {np.mean(res_lens):.2f}") + + print("\n- Histogram -") + bin_edges = [0, 1024, 2048, 4096, 8192, 16384, 32768] + hist = np.histogram(sample_lens, bins=bin_edges)[0] + for i in range(len(hist)): + print(f"L{bin_edges[i]} - {bin_edges[i+1]}: {hist[i]}") diff --git a/fastchat/data/hardcoded_questions.py b/fastchat/data/hardcoded_questions.py new file mode 100644 index 0000000000000000000000000000000000000000..a2bcff42ed9fab14fb3ed6934282b615cc7e5133 --- /dev/null +++ b/fastchat/data/hardcoded_questions.py @@ -0,0 +1,168 @@ +""" +Hardcoded question and answers. +""" +import json + + +def identity_questions(): + """ " + Adapted from https://github.com/young-geng/koala_data_pipeline/blob/main/process_hard_coded_data.py + """ + content = [] + + name = "Vicuna" + org = "Large Model Systems Organization (LMSYS)" + + def generate_conversations(questions, answers): + for q in questions: + for a in answers: + content.append( + { + "id": f"identity_{len(content)}", + "conversations": [ + {"from": "human", "value": q}, + {"from": "gpt", "value": a}, + ], + } + ) + + questions = [ + "Who are you?", + "What is your name?", + "Can you introduce yourself?", + "Can you tell me a little bit about yourself?", + "What's your name?", + "What are you called?", + "What are you?", + "Tell me your name.", + "Tell me about yourself.", + "Tell me about you.", + "Tell me who you are.", + "Please introduce yourself.", + ] + answers = [ + f"I am {name}, a language model trained by researchers from {org}.", + f"My name is {name}, and I'm a language model developed by {org}.", + f"You can call me {name}, and I was trained by {org} researchers as a language model.", + f"As a language model, I go by the name {name} and was trained by researchers from {org}.", + f"I'm a language model called {name}, and I was trained by {org} researchers.", + f"You may refer to me as {name}, a language model meticulously developed by the researchers at {org}.", + ] + generate_conversations(questions, answers) + + questions = [ + "Who created you?", + "Who made you?", + "Who built you?", + "Who programmed you?", + "Who trained you?", + "Who taught you?", + "Who developed you?", + ] + answers = [ + f"Researchers from {org} created me.", + f"I'm created by {org}.", + f"I'm built by researchers from {org}.", + f"I am a language model trained by researchers from {org}.", + f"I'm a language model developed by {org}.", + f"I'm a language model created by researchers from {org}.", + f"My creators are researchers from {org}.", + ] + generate_conversations(questions, answers) + + questions = [ + "Are you ChatGPT?", + "Are you GPT-2?", + "Are you GPT-3?", + "Are you GPT-4?", + "Are you davinci?", + "Are you davinci-001?", + "Are you davinci-002?", + "Are you davinci-003?", + "Are you curie?", + "Are you based on ChatGPT?", + "Are you based on GPT-2?", + "Are you based on GPT-3?", + "Are you based on GPT-4?", + "Are you based on davinci?", + "Are you based on davinci-001?", + "Are you based on davinci-002?", + "Are you based on davinci-003?", + "Are you based on curie?", + "Are you trained by OpenAI?", + "Are you trained by Google?", + "Are you trained by Microsoft?", + "Are you trained by Meta?", + "Are you trained by IBM?", + "Do you call OpenAI APIs?", + "Do you call Google APIs?", + "Do you call Microsoft APIs?", + "Do you call Meta APIs?", + "Do you call IBM APIs?", + "Are you created by OpenAI?", + "Are you created by Google?", + "Are you created by Microsoft?", + "Are you created by Meta?", + "Are you created by IBM?", + "Are you 
developed by OpenAI?", + "Are you developed by Google?", + "Are you developed by Microsoft?", + "Are you developed by Meta?", + "Are you developed by IBM?", + "Are you trained on OpenAI data?", + "Are you trained on Google data?", + "Are you trained on Microsoft data?", + "Are you trained on Meta data?", + "Are you trained on IBM data?", + "Are you trained with OpenAI data?", + "Are you trained with Google data?", + "Are you trained with Microsoft data?", + "Are you trained with Meta data?", + "Are you trained with IBM data?", + "Have you been trained with OpenAI data?", + "Have you been trained with Google data?", + "Have you been trained with Microsoft data?", + "Have you been trained with Meta data?", + "Have you been trained with IBM data?", + "Are you finetuned on OpenAI data?", + "Are you finetuned on Google data?", + "Are you finetuned on Microsoft data?", + "Are you finetuned on Meta data?", + "Are you finetuned on IBM data?", + "Are you finetuned with OpenAI data?", + "Are you finetuned with Google data?", + "Are you finetuned with Microsoft data?", + "Are you finetuned with Meta data?", + "Are you finetuned with IBM data?", + "Have you been finetuned with OpenAI data?", + "Have you been finetuned with Google data?", + "Have you been finetuned with Microsoft data?", + "Have you been finetuned with Meta data?", + "Have you been finetuned with IBM data?", + ] + answers = [ + f"No, I am a language model trained by researchers from {org}.", + f"No, I am a language model developed by researchers from {org}.", + f"No, I am a language model created by researchers from {org}.", + f"No, I am trained by researchers from {org}.", + f"No, I am developed by researchers from {org}.", + f"No, I am created by researchers from {org}.", + f"No, I'm a language model trained by researchers from {org}.", + f"No, I'm a language model developed by researchers from {org}.", + f"No, I'm a language model created by researchers from {org}.", + f"No, I'm trained by researchers from {org}.", + f"No, I'm developed by researchers from {org}.", + f"No, I'm created by researchers from {org}.", + ] + generate_conversations(questions, answers) + + return content + + +if __name__ == "__main__": + out_file = "hardcoded.json" + + content = [] + content.extend(identity_questions()) + + json.dump(content, open(out_file, "w"), indent=2) diff --git a/fastchat/data/inspect_data.py b/fastchat/data/inspect_data.py new file mode 100644 index 0000000000000000000000000000000000000000..df9227106be0bdc70946e6efc90b9cbd6fa7bf9b --- /dev/null +++ b/fastchat/data/inspect_data.py @@ -0,0 +1,33 @@ +""" +Usage: +python3 -m fastchat.data.inspect_data --in sharegpt_20230322_clean_lang_split.json +""" +import argparse +import json +import random + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--begin", type=int) + parser.add_argument("--random-n", type=int) + args = parser.parse_args() + + content = json.load(open(args.in_file, "r")) + + if args.random_n: + indices = [random.randint(0, len(content) - 1) for _ in range(args.random_n)] + elif args.begin: + indices = range(args.begin, len(content)) + else: + indices = range(0, len(content)) + + for idx in indices: + sample = content[idx] + print("=" * 40) + print(f"no: {idx}, id: {sample['id']}") + for conv in sample["conversations"]: + print(conv["from"] + ": ") + print(conv["value"]) + input() diff --git a/fastchat/data/merge.py b/fastchat/data/merge.py new file mode 100644 index 
0000000000000000000000000000000000000000..0ae63ea76cb4aae9f22f622db84857958965cd07 --- /dev/null +++ b/fastchat/data/merge.py @@ -0,0 +1,23 @@ +""" +Merge two conversation files into one + +Usage: python3 -m fastchat.data.merge --in file1.json file2.json --out merged.json +""" + +import argparse +import json + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True, nargs="+") + parser.add_argument("--out-file", type=str, default="merged.json") + args = parser.parse_args() + + new_content = [] + for in_file in args.in_file: + content = json.load(open(in_file, "r")) + new_content.extend(content) + + print(f"#out: {len(new_content)}") + json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) diff --git a/fastchat/data/optional_clean.py b/fastchat/data/optional_clean.py new file mode 100644 index 0000000000000000000000000000000000000000..47aecc1113fabfc76fa005cd34d2a0451efa294e --- /dev/null +++ b/fastchat/data/optional_clean.py @@ -0,0 +1,90 @@ +""" +Do optional cleaning (e.g., remove some languages). + +Usage: +python3 -m fastchat.data.optional_clean --in input.json --out output.json --keep-lang en +python3 -m fastchat.data.optional_clean --in input.json --out output.json --skip-lang en + +Requirement: +pip3 install polyglot pyicu pycld2 +""" +import argparse +import json +import re + +import polyglot +from polyglot.detect import Detector +import pycld2 +from tqdm import tqdm + + +def skip(conv, args): + # Remove certain languages + if args.keep_lang != "all" or args.skip_lang is not None: + text = "\n".join([x["value"] for x in conv["conversations"]]) + try: + lang_code = Detector(text).language.code + except (pycld2.error, polyglot.detect.base.UnknownLanguage): + lang_code = "unknown" + + if args.keep_lang != "all" and lang_code != args.keep_lang: + return True + + if lang_code == args.skip_lang: + return True + + # Remove repetitive numbers + if args.reduce_rep: + for sentence in conv["conversations"]: + val = sentence["value"] + sub = re.search(r"(\d)\1{8}", val) + if sub is not None: + return True + + return False + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--out-file", type=str) + parser.add_argument( + "--keep-lang", + type=str, + default="all", + choices=["all", "en"], + help="Only keep certain langauges.", + ) + parser.add_argument("--skip-lang", type=str, help="Skip a specific language.") + # NOTE: Be careful about reduce_rep which may remove some good data. 
+ # For example, addresses could have long consecutive 0's + parser.add_argument("--reduce-rep", action="store_true") + args = parser.parse_args() + + in_file = args.in_file + out_file = args.out_file + keep_lang = args.keep_lang + skip_lang = args.skip_lang + reduce_rep = args.reduce_rep + assert keep_lang == "all" or skip_lang is None + + if out_file is None: + out_file = "sharegpt_clean" + if keep_lang != "all": + out_file += "_" + keep_lang + if skip_lang is not None: + out_file += "_skip_" + skip_lang + if reduce_rep: + out_file += "_reduce_rep" + out_file += ".json" + + content = json.load(open(in_file, "r")) + num_conv = len(content) + + new_content = [] + for conv in tqdm(content): + if not skip(conv, args): + new_content.append(conv) + + print(f"#in: {len(content)}, #out: {len(new_content)}") + json.dump(new_content, open(out_file, "w"), indent=2, ensure_ascii=False) diff --git a/fastchat/data/optional_replace.py b/fastchat/data/optional_replace.py new file mode 100644 index 0000000000000000000000000000000000000000..1114151a9b077fd538e39721c8fc85e9a06d7a91 --- /dev/null +++ b/fastchat/data/optional_replace.py @@ -0,0 +1,82 @@ +""" +Do optional replace of bos/eos/pad/unk. + +Usage: +python3 -m fastchat.data.optional_replace --in input.json --out output.json --model-name-or-path + +Requirement: +pip3 install transformers tqdm +""" +import argparse +import json +import traceback + +import transformers +from tqdm import tqdm + + +def replace_special_tokens( + tokenizer: transformers.PreTrainedTokenizer, text: str +) -> str: + if not text: + return text + + def _insert_vline(token: str) -> str: + if len(token) < 2: + return " " + elif len(token) == 2: + return f"{token[0]}|{token[1]}" + else: + return f"{token[:1]}|{token[1:-1]}|{token[-1:]}" + + if tokenizer.bos_token: + text = text.replace(tokenizer.bos_token, _insert_vline(tokenizer.bos_token)) + if tokenizer.eos_token: + text = text.replace(tokenizer.eos_token, _insert_vline(tokenizer.eos_token)) + if tokenizer.pad_token: + text = text.replace(tokenizer.pad_token, _insert_vline(tokenizer.pad_token)) + if tokenizer.unk_token: + text = text.replace(tokenizer.unk_token, _insert_vline(tokenizer.unk_token)) + return text + + +def replace(conv, tokenizer): + # Replace bos/eos/pad/unk tokens + if tokenizer: + try: + for sentence in conv["conversations"]: + sentence["value"] = replace_special_tokens(tokenizer, sentence["value"]) + except Exception as e: + traceback.print_exc() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--out-file", type=str) + parser.add_argument( + "--model-name-or-path", + type=str, + help="The directory or address where the model token is stored.", + ) + args = parser.parse_args() + + in_file = args.in_file + out_file = args.out_file + tokenizer = None + if args.model_name_or_path: + tokenizer = transformers.AutoTokenizer.from_pretrained( + args.model_name_or_path, + trust_remote_code=True, + use_fast=False, + ) + + if out_file is None: + out_file = f"{in_file}_replace.json" + + content = json.load(open(in_file, "r")) + + for conv in tqdm(content): + replace(conv, tokenizer) + + json.dump(content, open(out_file, "w"), indent=2, ensure_ascii=False) diff --git a/fastchat/data/prepare_all.py b/fastchat/data/prepare_all.py new file mode 100644 index 0000000000000000000000000000000000000000..6d568703a4a5a18298ac51d92394e5142040c6c5 --- /dev/null +++ b/fastchat/data/prepare_all.py @@ -0,0 +1,42 @@ +"""Prepare all 
datasets.""" + +import argparse +import os + +from fastchat.utils import run_cmd + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--prefix", type=str, default="~/datasets/sharegpt_20230521") + parser.add_argument( + "--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf" + ) + parser.add_argument("--seq-len", type=int, default=4096) + args = parser.parse_args() + + in_prefix = args.prefix + model_path = args.model_name_or_path + seq_len = args.seq_len + prefix = ( + f"{in_prefix}_{seq_len}".replace("4096", "4k") + .replace("8192", "8k") + .replace("16384", "16k") + ) + + cmd_list = [ + f"python3 -m fastchat.data.clean_sharegpt --in {in_prefix}_html.json --out {prefix}_clean.json", + f"python3 -m fastchat.data.optional_clean --in {prefix}_clean.json --out {prefix}_clean_lang.json --skip-lang ko", + f"python3 -m fastchat.data.split_long_conversation --in {prefix}_clean_lang.json --out {prefix}_clean_lang_split.json --model-name {model_path} --max-length {seq_len}", + f"python3 -m fastchat.data.filter_wrong_format --in {prefix}_clean_lang_split.json --out {prefix}_clean_lang_split.json", + f"python3 -m fastchat.data.split_train_test --in {prefix}_clean_lang_split.json --ratio 0.99", + f"python3 -m fastchat.data.hardcoded_questions", + f"python3 -m fastchat.data.merge --in {prefix}_clean_lang_split_train.json hardcoded.json --out {prefix}_clean_lang_split_identity.json", + f"python3 -m fastchat.data.extract_gpt4_only --in {prefix}_clean_lang_split_identity.json", + f"python3 -m fastchat.data.extract_single_round --in {prefix}_clean_lang_split_identity.json", + ] + + for cmd in cmd_list: + ret = run_cmd(cmd) + if ret != 0: + exit(ret) diff --git a/fastchat/data/pretty_json.py b/fastchat/data/pretty_json.py new file mode 100644 index 0000000000000000000000000000000000000000..52eddf6c82687a544ae27a7ffad6d6f0458dcb29 --- /dev/null +++ b/fastchat/data/pretty_json.py @@ -0,0 +1,20 @@ +""" +Usage: +python3 pretty_json.py --in in.json --out out.json +""" + +import argparse +import json + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--out-file", type=str, required=True) + args = parser.parse_args() + + with open(args.in_file, "r") as fin: + data = json.load(fin) + + with open(args.out_file, "w") as fout: + json.dump(data, fout, indent=2, ensure_ascii=False) diff --git a/fastchat/data/sample.py b/fastchat/data/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..5ea94fadaeb243269d125a41b71a69ef15ce16fa --- /dev/null +++ b/fastchat/data/sample.py @@ -0,0 +1,40 @@ +""" +Sample some conversations from a file. 
+ +Usage: python3 -m fastchat.data.sample --in sharegpt.json --out sampled.json +""" +import argparse +import json + +import numpy as np + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--out-file", type=str, default="sampled.json") + parser.add_argument("--begin", type=int, default=0) + parser.add_argument("--end", type=int, default=100) + parser.add_argument("--max-length", type=int, default=1024) + parser.add_argument("--keep-order", action="store_true") + args = parser.parse_args() + + content = json.load(open(args.in_file, "r")) + if not args.keep_order: + np.random.seed(42) + np.random.shuffle(content) + + new_content = [] + for i in range(args.begin, min(args.end, len(content))): + sample = content[i] + concat = "" + for s in sample["conversations"]: + concat += s["value"] + + if len(concat) > args.max_length: + continue + + new_content.append(sample) + + print(f"#in: {len(content)}, #out: {len(new_content)}") + json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) diff --git a/fastchat/data/split_long_conversation.py b/fastchat/data/split_long_conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..413fa8bced590cdb476e67a6523c3967cb844acd --- /dev/null +++ b/fastchat/data/split_long_conversation.py @@ -0,0 +1,129 @@ +""" +Split long conversations based on certain max length. + +Usage: python3 -m fastchat.data.split_long_conversation \ + --in sharegpt_clean.json \ + --out sharegpt_split.json \ + --model-name-or-path $ +""" +import argparse +from concurrent.futures import ProcessPoolExecutor +import json +from typing import Dict, Sequence, Optional + +import transformers +from tqdm import tqdm + + +def make_sample(sample, start_idx, end_idx): + assert (end_idx - start_idx) % 2 == 0 + return { + "id": sample["id"] + "_" + str(start_idx), + "model": sample.get("model", ""), + "conversations": sample["conversations"][start_idx:end_idx], + } + + +tokenizer = max_length = None + + +def split_one_sample(sample): + tokenized_lens = [] + conversations = sample["conversations"] + conversations = conversations[: len(conversations) // 2 * 2] + for c in conversations: + length = len(tokenizer(c["value"]).input_ids) + 6 + tokenized_lens.append(length) + + start_idx = 0 + cur_len = 0 + + if len(conversations) % 2 != 0 or len(conversations) < 2: + return [] + + new_samples = [] + for i in range(0, len(conversations), 2): + tmp_len = tokenized_lens[i] + tokenized_lens[i + 1] + if cur_len + tmp_len > max_length: + new_samples.append(make_sample(sample, start_idx, i)) + start_idx = i + cur_len = 0 + elif i == len(conversations) - 2: + new_samples.append(make_sample(sample, start_idx, i + 2)) + + cur_len += tmp_len + + return new_samples + + +def worker(input_data): + result = [] + for sample in input_data: + result.extend(split_one_sample(sample)) + return result + + +def split_all(content, begin, end, tokenizer_, max_length_): + """ + Keep the maximum round of conversations within the max token length constraint + """ + global tokenizer, max_length + tokenizer = tokenizer_ + max_length = max_length_ + + content = content[begin:end] + new_content = [] + + # Split content into chunks + chunks = [content[i : i + 1000] for i in range(0, len(content), 1000)] + with ProcessPoolExecutor() as executor: + for result in tqdm(executor.map(worker, chunks), total=len(chunks)): + new_content.extend(result) + + return new_content + + +def 
filter_invalid_roles(content): + new_content = [] + for i, c in enumerate(content): + roles = ["human", "gpt"] + if len(c["conversations"]) <= 0: + continue + + valid = True + for j, s in enumerate(c["conversations"]): + if s["from"] != roles[j % 2]: + valid = False + break + + if valid: + new_content.append(c) + + return new_content + + +def main(args): + content = json.load(open(args.in_file, "r")) + tokenizer = transformers.AutoTokenizer.from_pretrained( + args.model_name_or_path, + model_max_length=args.max_length, + padding_side="right", + use_fast=False, + ) + new_content = split_all(content, args.begin, args.end, tokenizer, args.max_length) + new_content = filter_invalid_roles(new_content) + + print(f"#in: {len(content)}, #out: {len(new_content)}") + json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--out-file", type=str, default="sharegpt_split.json") + parser.add_argument("--begin", type=int) + parser.add_argument("--end", type=int) + parser.add_argument("--model-name-or-path", type=str, required=True) + parser.add_argument("--max-length", type=int, default=2048) + args = parser.parse_args() + main(args) diff --git a/fastchat/data/split_train_test.py b/fastchat/data/split_train_test.py new file mode 100644 index 0000000000000000000000000000000000000000..60b8960b57e30c28ef92652b17db7e52756f8aac --- /dev/null +++ b/fastchat/data/split_train_test.py @@ -0,0 +1,34 @@ +""" +Split the dataset into training and test set. + +Usage: python3 -m fastchat.data.split_train_test --in sharegpt.json +""" +import argparse +import json + +import numpy as np + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--begin", type=int, default=0) + parser.add_argument("--end", type=int, default=100) + parser.add_argument("--ratio", type=float, default=0.9) + args = parser.parse_args() + + content = json.load(open(args.in_file, "r")) + np.random.seed(0) + + perm = np.random.permutation(len(content)) + content = [content[i] for i in perm] + split = int(args.ratio * len(content)) + + train_set = content[:split] + test_set = content[split:] + + print(f"#train: {len(train_set)}, #test: {len(test_set)}") + train_name = args.in_file.replace(".json", "_train.json") + test_name = args.in_file.replace(".json", "_test.json") + json.dump(train_set, open(train_name, "w"), indent=2, ensure_ascii=False) + json.dump(test_set, open(test_name, "w"), indent=2, ensure_ascii=False) diff --git a/fastchat/llm_judge/README.md b/fastchat/llm_judge/README.md new file mode 100644 index 0000000000000000000000000000000000000000..63ba7cd06238ab7d06440717c3484ecd2e0821b2 --- /dev/null +++ b/fastchat/llm_judge/README.md @@ -0,0 +1,174 @@ +# LLM Judge +| [Paper](https://arxiv.org/abs/2306.05685) | [Leaderboard](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard) | + +In this package, you can use MT-bench questions and prompts to evaluate your models with LLM-as-a-judge. +MT-bench is a set of challenging multi-turn open-ended questions for evaluating chat assistants. +To automate the evaluation process, we prompt strong LLMs like GPT-4 to act as judges and assess the quality of the models' responses. 
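+
+As a quick, illustrative sketch (not part of this package), the snippet below shows one way to peek at the benchmark questions before running the steps in this README. The `data/mt_bench/question.jsonl` path and the `question_id`/`category`/`turns` field names are assumptions inferred from how the judging code consumes questions, so adjust them to your local layout.
+
+```
+import json
+from collections import Counter
+
+# Assumed location of the MT-bench questions (one JSON object per line).
+question_file = "data/mt_bench/question.jsonl"
+
+questions = []
+with open(question_file) as fin:
+    for line in fin:
+        if line.strip():
+            questions.append(json.loads(line))
+
+print(f"#questions: {len(questions)}")
+# "category" and "turns" are assumed field names; each MT-bench question has two turns.
+print(Counter(q.get("category", "unknown") for q in questions))
+print("first turn of question 0:", questions[0]["turns"][0])
+```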
+ +## Contents +- [Install](#install) +- [Review Pre-Generated Model Answers and Judgments](#review-pre-generated-model-answers-and-judgments) +- [MT-Bench](#mt-bench) +- [Agreement Computation](#agreement-computation) +- [Datasets](#datasets) +- [Citation](#citation) + +## Install +``` +git clone https://github.com/lm-sys/FastChat.git +cd FastChat +pip install -e ".[model_worker,llm_judge]" +``` + +## Review Pre-Generated Model Answers and Judgments +We provide pre-generated model answers and judgments for some models. +You can view them at this [demo](https://huggingface.co/spaces/lmsys/mt-bench). + +To download the pre-generated data, use +``` +python3 download_mt_bench_pregenerated.py +``` + +After downloading the data, you can view them locally by +``` +python3 qa_browser.py --share +``` +You can use this QA browser to view the answers generated by you later. + +## MT-Bench + +### Evaluate a model on MT-bench + +#### Step 1. Generate model answers to MT-bench questions +``` +python gen_model_answer.py --model-path [MODEL-PATH] --model-id [MODEL-ID] +``` +Arguments: + - `[MODEL-PATH]` is the path to the weights, which can be a local folder or a Hugging Face repo ID. + - `[MODEL-ID]` is a name you give to the model. + +e.g., +``` +python gen_model_answer.py --model-path lmsys/vicuna-7b-v1.5 --model-id vicuna-7b-v1.5 +``` +The answers will be saved to `data/mt_bench/model_answer/[MODEL-ID].jsonl`. + +To make sure FastChat loads the correct prompt template, see the supported models and how to add a new model [here](../../docs/model_support.md#how-to-support-a-new-model). + +You can also specify `--num-gpus-per-model` for model parallelism (needed for large 65B models) and `--num-gpus-total` to parallelize answer generation with multiple GPUs. + +> Note: if you experience slow answer generation, please refer to [Other Backends](#other-backends) section to use inference engine to speed up by 20x. + +#### Step 2. Generate GPT-4 judgments +There are several options to use GPT-4 as a judge, such as pairwise winrate and single-answer grading. +In MT-bench, we recommend single-answer grading as the default mode. +This mode asks GPT-4 to grade and give a score to model's answer directly without pairwise comparison. +For each turn, GPT-4 will give a score on a scale of 10. We then compute the average score on all turns. + +``` +export OPENAI_API_KEY=XXXXXX # set the OpenAI API key +python gen_judgment.py --model-list [LIST-OF-MODEL-ID] --parallel [num-concurrent-api-call] +``` + +e.g., +``` +python gen_judgment.py --model-list vicuna-13b-v1.3 alpaca-13b llama-13b claude-v1 gpt-3.5-turbo gpt-4 --parallel 2 +``` +The judgments will be saved to `data/mt_bench/model_judgment/gpt-4_single.jsonl` + +#### Step 3. Show MT-bench scores + +- Show the scores for selected models + ``` + python show_result.py --model-list vicuna-13b-v1.3 alpaca-13b llama-13b claude-v1 gpt-3.5-turbo gpt-4 + ``` +- Show all scores + ``` + python show_result.py + ``` + +--- + +### Other grading options +Besides score-based single-answer grading, we also support two additional grading options based on win rates: +- `pariwise-baseline`: run pairwise comparison against a baseline model. +- `pairwise-all`: run pairwise comparison between all model pairs on all questions. 
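+
+As a rough sanity check on these pairwise modes, a win-rate tally can be computed directly from the pairwise judgment file. The sketch below is illustrative only: `model_1` and `model_2` are fields used by the judgment-handling code in this repo, but the `winner` field name and its values are assumptions that may not match the actual judgment format; `show_result.py` remains the supported way to report results.
+
+```
+import json
+from collections import Counter
+
+wins, games = Counter(), Counter()
+with open("data/mt_bench/model_judgment/gpt-4_pair.jsonl") as fin:
+    for line in fin:
+        obj = json.loads(line)
+        m1, m2 = obj["model_1"], obj["model_2"]
+        games[m1] += 1
+        games[m2] += 1
+        # "winner" is an assumed field; "model_1"/"model_2"/"tie" values are assumptions too.
+        if obj.get("winner") == "model_1":
+            wins[m1] += 1
+        elif obj.get("winner") == "model_2":
+            wins[m2] += 1
+
+for model in sorted(games):
+    print(f"{model}: {wins[model] / games[model]:.2%} win rate over {games[model]} comparisons")
+```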
+ +#### Option 2: pairwise comparison against a baseline (default: gpt-3.5-turbo) + +- Generate GPT-4 judgments +``` +python gen_judgment.py --mode pairwise-baseline --model-list vicuna-13b-v1.3 alpaca-13b llama-13b --parallel 2 +``` +The judgments will be saved to `data/mt_bench/model_judgment/gpt-4_pair.jsonl` + +- Show results +``` +python show_result.py --mode pairwise-baseline +``` + +#### Option 3: Run GPT-4 judge with all pair comparisons + +Another option is to run pairwise comparisons on all possible pairs. +This could be more expensive when #models increases, but it gives you a more comprehensive information. + +``` +python gen_judgment.py --mode pairwise-all --model-list [LIST-OF-MODEL-ID] --parallel [num-concurrent-api-call] +``` + +``` +python show_result.py --mode pairwise-all +``` + +### How to get GPT-3.5/GPT-4/Claude's answer? +- `python gen_api_answer.py --model [MODEL-NAME]` to generate GPT-3.5/4 and Claude's answers. + + +### How to plot the radar figure? + +You can use this [colab notebook](https://colab.research.google.com/drive/15O3Y8Rxq37PuMlArE291P4OC6ia37PQK#scrollTo=5i8R0l-XqkgO) to plot the radar figure for MT-bench. + + + + +### Other backends +We can also use vLLM for answer generation, which can be faster for the models supported by vLLM. + +1. Launch a vLLM worker +``` +vllm serve [MODEL-PATH] --dtype auto +``` + - Arguments: + - `[MODEL-PATH]` is the path to the weights, which can be a local folder or a Hugging Face repo ID. + +2. Generate the answers +``` +python gen_api_answer.py --model [MODEL-NAME] --openai-api-base http://localhost:8000/v1 --parallel 50 +``` + - Arguments: + - `[MODEL-NAME]` is the name of the model from Step 1. + - `--parallel` is the number of concurrent API calls to the vLLM worker. + + +## Agreement Computation +We released 3.3K human annotations for model responses generated by 6 models in response to 80 MT-bench questions. The dataset is available at [lmsys/mt_bench_human_judgments](https://huggingface.co/datasets/lmsys/mt_bench_human_judgments). + +This Colab [notebook](https://colab.research.google.com/drive/1ctgygDRJhVGUJTQy8-bRZCl1WNcT8De6?usp=sharing) shows how to compute the agreement between humans and GPT-4 judge with the dataset. Our results show that humans and GPT-4 judge achieve over 80\% agreement, the same level of agreement between humans. + +## Datasets +- [Chatbot Arena Conversation Dataset](https://huggingface.co/datasets/lmsys/chatbot_arena_conversations) +- [MT-bench Human Annotation Dataset](https://huggingface.co/datasets/lmsys/mt_bench_human_judgments) + + +## Citation +Please cite the following paper if you find the code or datasets helpful. +``` +@misc{zheng2023judging, + title={Judging LLM-as-a-judge with MT-Bench and Chatbot Arena}, + author={Lianmin Zheng and Wei-Lin Chiang and Ying Sheng and Siyuan Zhuang and Zhanghao Wu and Yonghao Zhuang and Zi Lin and Zhuohan Li and Dacheng Li and Eric. P Xing and Hao Zhang and Joseph E. Gonzalez and Ion Stoica}, + year={2023}, + eprint={2306.05685}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` diff --git a/fastchat/llm_judge/clean_judgment.py b/fastchat/llm_judge/clean_judgment.py new file mode 100644 index 0000000000000000000000000000000000000000..d139ed7f22b4eafe16702b67cf528a908061c768 --- /dev/null +++ b/fastchat/llm_judge/clean_judgment.py @@ -0,0 +1,93 @@ +""" +Clean model judgment files. 
+""" +import argparse +import json + +selected_models = [ + "alpaca-13b", + "baize-v2-13b", + "chatglm-6b", + "claude-instant-v1", + "claude-v1", + "dolly-v2-12b", + "falcon-40b-instruct", + "fastchat-t5-3b", + "gpt-3.5-turbo", + "gpt-4", + "gpt4all-13b-snoozy", + "guanaco-33b", + "guanaco-65b", + "h2ogpt-oasst-open-llama-13b", + "koala-13b", + "llama-13b", + "mpt-30b-chat", + "mpt-30b-instruct", + "mpt-7b-chat", + "nous-hermes-13b", + "oasst-sft-4-pythia-12b", + "oasst-sft-7-llama-30b", + "palm-2-chat-bison-001", + "rwkv-4-raven-14b", + "stablelm-tuned-alpha-7b", + "tulu-30b", + "vicuna-13b-v1.3", + "vicuna-33b-v1.3", + "vicuna-7b-v1.3", + "wizardlm-13b", + "wizardlm-30b", +] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--infile", type=str) + args = parser.parse_args() + + infile = args.infile + outfile = infile.replace(".jsonl", "_clean.jsonl") + + raw_lines = open(infile).readlines() + rets = [] + models = set() + visited = set() + for line in raw_lines: + obj = json.loads(line) + + if "model_1" in obj: # pair + model = obj["model_1"] + key = ( + obj["model_1"], + obj["model_2"], + obj["question_id"], + tuple(obj["judge"]), + ) + else: # single + model = obj["model"] + key = (obj["model"], obj["question_id"], tuple(obj["judge"])) + + if key in visited: + continue + visited.add(key) + + if model not in selected_models: + continue + models.add(model) + rets.append(obj) + + models = sorted(list(models)) + missing_models = [x for x in selected_models if x not in models] + print(f"in models: {models}, number: {len(models)}") + print(f"missing models: {missing_models}") + print(f"#in: {len(raw_lines)}, #out: {len(rets)}") + rets.sort( + key=lambda x: ( + x["model"] if "model" in x else x["model_1"], + x["question_id"], + x["turn"], + ) + ) + + with open(outfile, "w") as fout: + for x in rets: + fout.write(json.dumps(x) + "\n") diff --git a/fastchat/llm_judge/common.py b/fastchat/llm_judge/common.py new file mode 100644 index 0000000000000000000000000000000000000000..d2640d6013b08c05a3ae3a1b887f13d3a4b1f39c --- /dev/null +++ b/fastchat/llm_judge/common.py @@ -0,0 +1,711 @@ +""" +Common data structures and utilities. 
+""" + +import ast +import dataclasses +import glob +import json +import os +import re +import time +from typing import Optional + +import openai +import anthropic + +from fastchat.model.model_adapter import ( + get_conversation_template, + ANTHROPIC_MODEL_LIST, + OPENAI_MODEL_LIST, +) + +# API setting constants +API_MAX_RETRY = 16 +API_RETRY_SLEEP = 10 +API_ERROR_OUTPUT = "$ERROR$" + +TIE_DELTA = 0.1 + +# Categories that need reference answers +NEED_REF_CATS = ["math", "reasoning", "coding", "arena-hard-200"] + +# Extract scores from judgments +two_score_pattern = re.compile("\[\[(\d+\.?\d*),\s?(\d+\.?\d*)\]\]") +two_score_pattern_backup = re.compile("\[(\d+\.?\d*),\s?(\d+\.?\d*)\]") +one_score_pattern = re.compile("\[\[(\d+\.?\d*)\]\]") +one_score_pattern_backup = re.compile("\[(\d+\.?\d*)\]") + +# Sampling temperature configs for +temperature_config = { + "writing": 0.7, + "roleplay": 0.7, + "extraction": 0.0, + "math": 0.0, + "coding": 0.0, + "reasoning": 0.0, + "stem": 0.1, + "humanities": 0.1, + "arena-hard-200": 0.0, +} + +reverse_model_map = { + "model_1": "model_2", + "model_2": "model_1", +} + + +@dataclasses.dataclass +class Judge: + model_name: str + prompt_template: dict + ref_based: bool = False + multi_turn: bool = False + + +@dataclasses.dataclass +class MatchSingle: + question: dict + model: str + answer: dict + judge: Judge + ref_answer: dict = None + multi_turn: bool = False + + +@dataclasses.dataclass +class MatchPair: + question: dict + model_1: str + model_2: str + answer_1: dict + answer_2: dict + judge: Judge + ref_answer: dict = None + multi_turn: bool = False + + +def load_questions(question_file: str, begin: Optional[int], end: Optional[int]): + """Load questions from a file.""" + questions = [] + with open(question_file, "r") as ques_file: + for line in ques_file: + if line: + questions.append(json.loads(line)) + questions = questions[begin:end] + return questions + + +def load_model_answers(answer_dir: str): + """Load model answers. + + The return value is a python dict of type: + Dict[model_name: str -> Dict[question_id: int -> answer: dict]] + """ + filenames = glob.glob(os.path.join(answer_dir, "*.jsonl")) + filenames.sort() + model_answers = {} + + for filename in filenames: + model_name = os.path.basename(filename)[:-6] + answer = {} + with open(filename) as fin: + for line in fin: + line = json.loads(line) + answer[line["question_id"]] = line + model_answers[model_name] = answer + + return model_answers + + +def load_judge_prompts(prompt_file: str): + """Load judge prompts. 
+ + The return value is a python dict of type: + Dict[judge_name: str -> dict] + """ + prompts = {} + with open(prompt_file) as fin: + for line in fin: + line = json.loads(line) + prompts[line["name"]] = line + return prompts + + +def run_judge_single(question, answer, judge, ref_answer, multi_turn=False): + kwargs = {} + model = judge.model_name + if ref_answer is not None: + kwargs["ref_answer_1"] = ref_answer["choices"][0]["turns"][0] + if multi_turn: + kwargs["ref_answer_2"] = ref_answer["choices"][0]["turns"][1] + + if multi_turn: + user_prompt = judge.prompt_template["prompt_template"].format( + question_1=question["turns"][0], + question_2=question["turns"][1], + answer_1=answer["choices"][0]["turns"][0], + answer_2=answer["choices"][0]["turns"][1], + **kwargs, + ) + else: + user_prompt = judge.prompt_template["prompt_template"].format( + question=question["turns"][0], + answer=answer["choices"][0]["turns"][0], + **kwargs, + ) + + rating = -1 + + system_prompt = judge.prompt_template["system_prompt"] + conv = get_conversation_template(model) + conv.set_system_message(system_prompt) + conv.append_message(conv.roles[0], user_prompt) + conv.append_message(conv.roles[1], None) + + if model in OPENAI_MODEL_LIST: + judgment = chat_completion_openai(model, conv, temperature=0, max_tokens=2048) + elif model in ANTHROPIC_MODEL_LIST: + judgment = chat_completion_anthropic( + model, conv, temperature=0, max_tokens=1024 + ) + else: + raise ValueError(f"Invalid judge model name: {model}") + + if judge.prompt_template["output_format"] == "[[rating]]": + match = re.search(one_score_pattern, judgment) + if not match: + match = re.search(one_score_pattern_backup, judgment) + + if match: + rating = ast.literal_eval(match.groups()[0]) + else: + rating = -1 + else: + raise ValueError( + f"invalid output format: {judge.prompt_template['output_format']}" + ) + + return rating, user_prompt, judgment + + +def play_a_match_single(match: MatchSingle, output_file: str): + question, model, answer, judge, ref_answer, multi_turn = ( + match.question, + match.model, + match.answer, + match.judge, + match.ref_answer, + match.multi_turn, + ) + + if judge.prompt_template["type"] == "single": + score, user_prompt, judgment = run_judge_single( + question, answer, judge, ref_answer, multi_turn=multi_turn + ) + + question_id = question["question_id"] + turn = 1 if not multi_turn else 2 + result = { + "question_id": question_id, + "model": model, + "judge": (judge.model_name, judge.prompt_template["name"]), + "user_prompt": user_prompt, + "judgment": judgment, + "score": score, + "turn": turn, + "tstamp": time.time(), + } + print( + f"question: {question_id}, turn: {turn}, model: {model}, " + f"score: {score}, " + f"judge: {(judge.model_name, judge.prompt_template['name'])}" + ) + else: + raise ValueError(f"invalid judge type: {judge['type']}") + + if output_file: + os.makedirs(os.path.dirname(output_file), exist_ok=True) + with open(output_file, "a") as fout: + fout.write(json.dumps(result) + "\n") + + return result + + +def run_judge_pair(question, answer_a, answer_b, judge, ref_answer, multi_turn=False): + kwargs = {} + model = judge.model_name + if ref_answer is not None: + kwargs["ref_answer_1"] = ref_answer["choices"][0]["turns"][0] + if multi_turn: + kwargs["ref_answer_2"] = ref_answer["choices"][0]["turns"][1] + + if multi_turn: + system_prompt = judge.prompt_template["system_prompt"] + user_prompt = judge.prompt_template["prompt_template"].format( + question_1=question["turns"][0], + 
question_2=question["turns"][1], + answer_a_1=answer_a["choices"][0]["turns"][0], + answer_b_1=answer_b["choices"][0]["turns"][0], + answer_a_2=answer_a["choices"][0]["turns"][1], + answer_b_2=answer_b["choices"][0]["turns"][1], + **kwargs, + ) + else: + system_prompt = judge.prompt_template["system_prompt"] + user_prompt = judge.prompt_template["prompt_template"].format( + question=question["turns"][0], + answer_a=answer_a["choices"][0]["turns"][0], + answer_b=answer_b["choices"][0]["turns"][0], + **kwargs, + ) + + winner = "error" + + conv = get_conversation_template(model) + conv.append_message(conv.roles[0], user_prompt) + conv.append_message(conv.roles[1], None) + + if model in OPENAI_MODEL_LIST: + conv.set_system_message(system_prompt) + judgment = chat_completion_openai(model, conv, temperature=0, max_tokens=2048) + elif model in ANTHROPIC_MODEL_LIST: + if system_prompt != "You are a helpful assistant.": + user_prompt = "[Instruction]\n" + system_prompt + "\n\n" + user_prompt + conv.messages[0][1] = user_prompt + judgment = chat_completion_anthropic( + model, conv, temperature=0, max_tokens=1024 + ) + else: + raise ValueError(f"Invalid judge model name: {model}") + + if judge.prompt_template["output_format"] == "[[A]]": + if "[[A]]" in judgment: + winner = "A" + elif "[[B]]" in judgment: + winner = "B" + elif "[[C]]" in judgment: + winner = "tie" + else: + winner = "error" + elif judge.prompt_template["output_format"] == "[[rating_a,rating_b]]": + match = re.search(two_score_pattern, judgment) + if not match: + match = re.search(two_score_pattern_backup, judgment) + if match: + scores = [ast.literal_eval(s.strip()) for s in match.groups()] + if abs(scores[0] - scores[1]) <= TIE_DELTA: + winner = "tie" + elif scores[0] > scores[1]: + winner = "A" + else: + winner = "B" + else: + winner = "error" + else: + raise ValueError( + f"invalid output format: {judge.prompt_template['output_format']}" + ) + + return winner, user_prompt, judgment + + +def play_a_match_pair(match: MatchPair, output_file: str): + question, model_1, model_2, answer_1, answer_2, judge, ref_answer, multi_turn = ( + match.question, + match.model_1, + match.model_2, + match.answer_1, + match.answer_2, + match.judge, + match.ref_answer, + match.multi_turn, + ) + + if judge.prompt_template["type"] == "pairwise": + g1_winner, g1_user_prompt, g1_judgment = run_judge_pair( + question, answer_1, answer_2, judge, ref_answer, multi_turn=multi_turn + ) + g2_winner, g2_user_prompt, g2_judgment = run_judge_pair( + question, answer_2, answer_1, judge, ref_answer, multi_turn=multi_turn + ) + + g1_map = {"A": "model_1", "B": "model_2"} + g2_map = {"A": "model_2", "B": "model_1"} + g1_winner = g1_map.get(g1_winner, g1_winner) + g2_winner = g2_map.get(g2_winner, g2_winner) + question_id = question["question_id"] + turn = 1 if not multi_turn else 2 + + result = { + "question_id": question_id, + "model_1": model_1, + "model_2": model_2, + "g1_winner": g1_winner, + "g2_winner": g2_winner, + "judge": (judge.model_name, judge.prompt_template["name"]), + "g1_user_prompt": g1_user_prompt, + "g1_judgment": g1_judgment, + "g2_user_prompt": g2_user_prompt, + "g2_judgment": g2_judgment, + "turn": turn, + "tstamp": time.time(), + } + + print( + f"question: {question_id}, turn: {turn}, model_1: {model_1}, model_2: {model_2}, " + f"g1_winner: {g1_winner}, g2_winner: {g2_winner}, " + f"judge: {(judge.model_name, judge.prompt_template['name'])}" + ) + elif judge.prompt_template["type"] == "single": + m1_score, m1_user_prompt, m1_judgment = 
run_judge_single( + question, answer_1, judge + ) + m2_score, m2_user_prompt, m2_judgment = run_judge_single( + question, answer_2, judge + ) + + if abs(m1_score - m2_score) <= TIE_DELTA: + winner = "tie" + elif m1_score > m2_score: + winner = "model_1" + else: + winner = "model_2" + + question_id = question["question_id"] + result = { + "question_id": question_id, + "model_1": model_1, + "model_2": model_2, + "g1_winner": winner, + "g2_winner": winner, + "judge": (judge.model_name, judge.prompt_template["name"]), + "g1_user_prompt": m1_user_prompt, + "g1_judgment": m1_judgment, + "g2_user_prompt": m2_user_prompt, + "g2_judgment": m2_judgment, + "m1_score": m1_score, + "m2_score": m2_score, + "tstamp": time.time(), + } + print( + f"question: {question_id}, model_1: {model_1}, model_2: {model_2}, " + f"winner: {winner}, m1_score: {m1_score}, m2_score: {m2_score}, " + f"judge: {(judge.model_name, judge.prompt_template['name'])}" + ) + else: + raise ValueError(f"invalid judge type: {judge['type']}") + + if output_file: + os.makedirs(os.path.dirname(output_file), exist_ok=True) + with open(output_file, "a") as fout: + fout.write(json.dumps(result) + "\n") + + return result + + +def chat_completion_openai(model, conv, temperature, max_tokens, api_dict=None): + if api_dict is not None: + openai.api_base = api_dict["api_base"] + openai.api_key = api_dict["api_key"] + output = API_ERROR_OUTPUT + for _ in range(API_MAX_RETRY): + try: + messages = conv.to_openai_api_messages() + response = openai.ChatCompletion.create( + model=model, + messages=messages, + n=1, + temperature=temperature, + max_tokens=max_tokens, + ) + output = response["choices"][0]["message"]["content"] + break + except openai.error.OpenAIError as e: + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + + return output + + +def chat_completion_openai_azure(model, conv, temperature, max_tokens, api_dict=None): + openai.api_type = "azure" + openai.api_version = "2023-07-01-preview" + if api_dict is not None: + openai.api_base = api_dict["api_base"] + openai.api_key = api_dict["api_key"] + else: + openai.api_base = os.environ["AZURE_OPENAI_ENDPOINT"] + openai.api_key = os.environ["AZURE_OPENAI_KEY"] + + if "azure-" in model: + model = model[6:] + + output = API_ERROR_OUTPUT + for _ in range(API_MAX_RETRY): + try: + messages = conv.to_openai_api_messages() + response = openai.ChatCompletion.create( + engine=model, + messages=messages, + n=1, + temperature=temperature, + max_tokens=max_tokens, + ) + output = response["choices"][0]["message"]["content"] + break + except openai.error.OpenAIError as e: + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + except openai.error.InvalidRequestError as e: + print(type(e), e) + break + except KeyError: + print(response) + break + + return output + + +def chat_completion_anthropic(model, conv, temperature, max_tokens, api_dict=None): + if api_dict is not None and "api_key" in api_dict: + api_key = api_dict["api_key"] + else: + api_key = os.environ["ANTHROPIC_API_KEY"] + + output = API_ERROR_OUTPUT + for _ in range(API_MAX_RETRY): + try: + c = anthropic.Anthropic(api_key=api_key) + prompt = conv.get_prompt() + response = c.completions.create( + model=model, + prompt=prompt, + stop_sequences=[anthropic.HUMAN_PROMPT], + max_tokens_to_sample=max_tokens, + temperature=temperature, + ) + output = response.completion + break + except anthropic.APIError as e: + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + return output.strip() + + +def chat_completion_palm(chat_state, model, conv, temperature, 
max_tokens): + from fastchat.serve.api_provider import init_palm_chat + + assert model == "palm-2-chat-bison-001" + + if chat_state is None: + chat_state = init_palm_chat("chat-bison@001") + + parameters = { + "temperature": temperature, + "top_p": 0.8, + "top_k": 40, + "max_output_tokens": max_tokens, + } + output = API_ERROR_OUTPUT + for _ in range(API_MAX_RETRY): + try: + response = chat_state.send_message(conv.messages[-2][1], **parameters) + output = response.text + break + except Exception as e: + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + return chat_state, output + + +def normalize_game_key_single(gamekey, result): + """Make the model names sorted in a game key.""" + qid, model_1, model_2 = gamekey + if model_1 < model_2: + return gamekey, result + else: + new_gamekey = (qid, model_2, model_1) + new_result = { + "winners": tuple(reverse_model_map.get(x, x) for x in result["winners"]), + "g1_judgment": result["g2_judgment"], + "g2_judgment": result["g1_judgment"], + } + return new_gamekey, new_result + + +def normalize_game_key_dict(judgment_dict): + """Make the model names sorted in the game keys.""" + ret = {} + for key, value in judgment_dict.items(): + new_key, new_value = normalize_game_key_single(key, value) + ret[new_key] = new_value + return ret + + +def load_pairwise_model_judgments(filename: str): + """Load model judgments. + + The return value is a dict of type: + Dict[judge: Tuple -> Dict[game_key: tuple -> game_result: dict] + """ + judge_dict = {} + + for line in open(filename): + obj = json.loads(line) + judge = tuple(obj["judge"]) + qid, model_1, model_2 = obj["question_id"], obj["model_1"], obj["model_2"] + + if judge not in judge_dict: + judge_dict[judge] = {} + + if "winner" in obj: + winner = obj["winner"] + elif "g1_winner" in obj and "g2_winner" in obj: + g1_winner, g2_winner = obj["g1_winner"], obj["g2_winner"] + if g1_winner == g2_winner: + winner = g1_winner + else: + winner = "inconsistent" + else: + raise ValueError(f"Invalid keys: {list(obj.keys())}") + + gamekey = (qid, model_1, model_2) + winners = (winner,) + + judge_dict[judge][gamekey] = { + "winners": winners, + "g1_judgment": obj["g1_judgment"], + "g2_judgment": obj["g2_judgment"], + } + + # Make the model names sorted in the game keys + normalized = {} + for judge, value in judge_dict.items(): + normalized[judge] = normalize_game_key_dict(value) + return normalized + + +def load_single_model_judgments(filename: str): + """Load model judgments. 
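+
+    This loads single-answer grading judgments (one score per model answer),
+    as opposed to the pairwise judgments loaded by the function above.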
+ + The return value is a dict of type: + Dict[judge: Tuple -> Dict[game_key: tuple -> game_result: dict] + """ + judge_dict = {} + + for line in open(filename): + obj = json.loads(line) + judge = tuple(obj["judge"]) + qid, model = obj["question_id"], obj["model"] + + if judge not in judge_dict: + judge_dict[judge] = {} + + gamekey = (qid, model) + + judge_dict[judge][gamekey] = { + "score": obj["score"], + "judgment": obj["judgment"], + } + return judge_dict + + +def resolve_pairwise_judgment_dict( + question, model_judgments_normal, model_judgments_math, multi_turn=False +): + """Return the correct pairwise judge.""" + if multi_turn: + if question["category"] in NEED_REF_CATS: + return model_judgments_math[("gpt-4", "pair-math-v1-multi-turn")] + return model_judgments_normal[("gpt-4", "pair-v2-multi-turn")] + + if question["category"] in NEED_REF_CATS: + return model_judgments_math[("gpt-4", "pair-math-v1")] + else: + return model_judgments_normal[("gpt-4", "pair-v2")] + + +def resolve_single_judgment_dict( + question, model_judgments_normal, model_judgments_math, multi_turn=False +): + """Return the correct single answer grading judge.""" + if multi_turn: + if question["category"] in NEED_REF_CATS: + return model_judgments_math[("gpt-4", "single-math-v1-multi-turn")] + return model_judgments_normal[("gpt-4", "single-v1-multi-turn")] + + if question["category"] in NEED_REF_CATS: + return model_judgments_math[("gpt-4", "single-math-v1")] + else: + return model_judgments_normal[("gpt-4", "single-v1")] + + +def get_pairwise_judge_explanation(gamekey, judgment_dict): + """Get model judge explanation.""" + try: + qid, model_1, model_2 = gamekey + if model_1 < model_2: + res = judgment_dict[gamekey] + g1_judgment, g2_judgment = res["g1_judgment"], res["g2_judgment"] + else: + new_gamekey = (qid, model_2, model_1) + res = judgment_dict[new_gamekey] + + model_1, model_2 = model_1, model_2 + g1_judgment, g2_judgment = res["g2_judgment"], res["g1_judgment"] + + return ( + f"**Game 1**. **A**: {model_1}, **B**: {model_2}\n\n" + f"**Judgment**: {g1_judgment}" + + f"\n\n`--------------------------`\n\n" + + f"**Game 2**. **A**: {model_2}, **B**: {model_1}\n\n" + f"**Judgment**: {g2_judgment}" + ) + except KeyError: + return "N/A" + + +def get_single_judge_explanation(gamekey, judgment_dict): + """Get model judge explanation.""" + try: + qid, model = gamekey + + res = judgment_dict[gamekey] + + g1_judgment = res["judgment"] + g1_score = res["score"] + + return ( + f"**Game 1**. 
**A**: {model}, **Score**: {g1_score}\n\n" + f"**Judgment**: {g1_judgment}" + ) + except KeyError: + return "N/A" + + +def check_data(questions, model_answers, ref_answers, models, judges): + # check model answers + for m in models: + assert m in model_answers, f"Missing model answer for {m}" + m_answer = model_answers[m] + for q in questions: + assert ( + q["question_id"] in m_answer + ), f"Missing model {m}'s answer to Question {q['question_id']}" + # check ref answers + for jg in judges.values(): + if not jg.ref_based: + continue + for q in questions: + if q["category"] not in NEED_REF_CATS: + continue + assert ( + q["question_id"] in ref_answers[jg.model_name] + ), f"Missing reference answer to Question {q['question_id']} for judge {jg.model_name}" + + +def get_model_list(answer_dir): + file_paths = glob.glob(f"{answer_dir}/*.jsonl") + file_names = [os.path.splitext(os.path.basename(f))[0] for f in file_paths] + return file_names diff --git a/fastchat/llm_judge/compute_agreement.py b/fastchat/llm_judge/compute_agreement.py new file mode 100644 index 0000000000000000000000000000000000000000..1b940bf5a5bdb02ca093fac88d883e3a45da4322 --- /dev/null +++ b/fastchat/llm_judge/compute_agreement.py @@ -0,0 +1,140 @@ +""" +Compute agreement among judges. + +Usage: +python compute_agreement.py --judges gpt4-pair human --votefiles human_judgments.json gpt4_pair_judgments.json +python compute_agreement.py --judges human human --votefiles human_judgments.json +""" +import argparse +import json +import os + +import numpy as np + + +def get_judge_name(judge): + if isinstance(judge, list) and judge[0] == "gpt-4" and judge[1].startswith("pair"): + return "gpt4-pair" + if judge.startswith("expert"): + return "human" + if judge.startswith("author"): + return "author" + + +def revert(vote): + if vote == "model_a": + return "model_b" + elif vote == "model_b": + return "model_a" + return vote + + +def get_mt_bench_votes_data(raw_votes): + data = [{}, {}] + + for judge_votes in raw_votes: + for vote in judge_votes: + turn = vote["turn"] - 1 + if vote["model_a"] < vote["model_b"]: + key = (vote["question_id"], vote["model_a"], vote["model_b"]) + winner = vote["winner"] + else: + key = (vote["question_id"], vote["model_b"], vote["model_a"]) + winner = revert(vote["winner"]) + judge = get_judge_name(vote["judge"]) + if key not in data[turn]: + data[turn][key] = {} + if judge not in data[turn][key]: + data[turn][key][judge] = [] + data[turn][key][judge].append(winner) + + return data + + +def convertvote(vote): + if "tie" in vote: + return "tie" + return vote + + +def equalvote(vote1, vote2): + if "tie" in vote1 and "tie" in vote2: + return True + return vote1 == vote2 + + +# data: Dict[qid -> List[vote]] +def get_mt_bench_agreement(data, judge1, judge2, ban): + if judge1.startswith("gpt4") and judge2 == "human": + stats = [0, 0] + for votes in data.values(): + if judge1 not in votes or judge2 not in votes: + continue + assert len(votes[judge1]) == 1 + if convertvote(votes[judge1][0]) in ban: + continue + for v in votes[judge2]: + if convertvote(v) in ban: + continue + stats[1] += 1 + stats[0] += equalvote(votes[judge1][0], v) + return stats[0], stats[1] + elif judge1 == "human" and judge2 == "human": + stats = [0, 0] + for votes in data.values(): + if "human" not in votes: + continue + for i in range(len(votes["human"]) - 1): + for j in range(i + 1, len(votes["human"])): + if ( + convertvote(votes["human"][i]) in ban + or convertvote(votes["human"][j]) in ban + ): + continue + stats[1] += 1 + stats[0] += 
equalvote(votes["human"][i], votes["human"][j]) + return stats[0], stats[1] + else: + raise Exception("Unsupported judges.") + + +def run_mt_bench_agreement(judges, votefiles): + # votes[i]: List of votes + votes = [] + for filename in votefiles: + with open(filename, "r") as f: + data = json.load(f) + votes.append(data) + + data = get_mt_bench_votes_data(votes) + + agree, total = get_mt_bench_agreement(data[0], judges[0], judges[1], ban=[]) + print( + f"turn 1 with tie. #total: {total}, #agree: {agree}, ratio: {agree/total:.2f}" + ) + agree, total = get_mt_bench_agreement(data[0], judges[0], judges[1], ban=["tie"]) + print( + f"turn 1 without tie. #total: {total}, #agree: {agree}, ratio: {agree/total:.2f}" + ) + agree, total = get_mt_bench_agreement(data[1], judges[0], judges[1], ban=[]) + print( + f"turn 2 with tie. #total: {total}, #agree: {agree}, ratio: {agree/total:.2f}" + ) + agree, total = get_mt_bench_agreement(data[1], judges[0], judges[1], ban=["tie"]) + print( + f"turn 2 without tie. #total: {total}, #agree: {agree}, ratio: {agree/total:.2f}" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--judges", nargs=2, type=str, default=["gpt4-pair", "human"]) + parser.add_argument( + "--votefiles", + nargs="+", + type=str, + default=["gpt4_judgments.json", "human_judgments.json"], + ) + args = parser.parse_args() + + run_mt_bench_agreement(args.judges, args.votefiles) diff --git a/fastchat/llm_judge/data/judge_prompts.jsonl b/fastchat/llm_judge/data/judge_prompts.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..4ec7524cbcdf308766fddc52df31e203316ad75f --- /dev/null +++ b/fastchat/llm_judge/data/judge_prompts.jsonl @@ -0,0 +1,8 @@ +{"name": "pair-v2", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user's instructions and answers the user's question better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Begin your evaluation by comparing the two responses and provide a short explanation. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "[User Question]\n{question}\n\n[The Start of Assistant A's Answer]\n{answer_a}\n[The End of Assistant A's Answer]\n\n[The Start of Assistant B's Answer]\n{answer_b}\n[The End of Assistant B's Answer]", "description": "Prompt for general questions", "category": "general", "output_format": "[[A]]"} +{"name": "pair-v2-multi-turn", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user questions. You should choose the assistant that follows the user's instructions and answers the user's questions better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. 
You should focus on who provides a better answer to the second user question. Begin your evaluation by comparing the responses of the two assistants and provide a short explanation. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_a_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_a_2}\n\n<|The End of Assistant A's Conversation with User|>\n\n\n<|The Start of Assistant B's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant B:\n{answer_b_1}\n\n### User:\n{question_2}\n\n### Assistant B:\n{answer_b_2}\n\n<|The End of Assistant B's Conversation with User|>", "description": "Prompt for multi-turn general questions", "category": "general", "output_format": "[[A]]"} +{"name": "pair-math-v1", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. Your evaluation should consider correctness and helpfulness. You will be given a reference answer, assistant A's answer, and assistant B's answer. Your job is to evaluate which assistant's answer is better. Begin your evaluation by comparing both assistants' answers with the reference answer. Identify and correct any mistakes. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "[User Question]\n{question}\n\n[The Start of Reference Answer]\n{ref_answer_1}\n[The End of Reference Answer]\n\n[The Start of Assistant A's Answer]\n{answer_a}\n[The End of Assistant A's Answer]\n\n[The Start of Assistant B's Answer]\n{answer_b}\n[The End of Assistant B's Answer]", "description": "Prompt for math questions", "category": "math", "output_format": "[[A]]"} +{"name": "pair-math-v1-multi-turn", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user questions. Your evaluation should consider correctness and helpfulness. You will be given reference answers, the assistant A's answers, the assistant B's answers. Your job is to determine which assistant provides correct and helpful answers to the second user question. Begin your evaluation by comparing both assistants' answers with the reference answers. Identify and correct any mistakes. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. 
After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "<|The Start of Reference Answer|>\n\n### User:\n{question_1}\n\n### Reference answer:\n{ref_answer_1}\n\n### User:\n{question_2}\n\n### Reference answer:\n{ref_answer_2}\n\n<|The End of Reference Answer|>\n\n\n<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_a_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_a_2}\n\n<|The End of Assistant A's Conversation with User|>\n\n\n<|The Start of Assistant B's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant B:\n{answer_b_1}\n\n### User:\n{question_2}\n\n### Assistant B:\n{answer_b_2}\n\n<|The End of Assistant B's Conversation with User|>", "description": "Prompt for multi-turn general questions", "category": "general", "output_format": "[[A]]"} +{"name": "single-v1", "type": "single", "system_prompt": "You are a helpful assistant.", "prompt_template": "[Instruction]\nPlease act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response. Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n[Question]\n{question}\n\n[The Start of Assistant's Answer]\n{answer}\n[The End of Assistant's Answer]", "description": "Prompt for general questions", "category": "general", "output_format": "[[rating]]"} +{"name": "single-math-v1", "type": "single", "system_prompt": "You are a helpful assistant.", "prompt_template": "[Instruction]\nPlease act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider correctness and helpfulness. You will be given a reference answer and the assistant's answer. Begin your evaluation by comparing the assistant's answer with the reference answer. Identify and correct any mistakes. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n[Question]\n{question}\n\n[The Start of Reference Answer]\n{ref_answer_1}\n[The End of Reference Answer]\n\n[The Start of Assistant's Answer]\n{answer}\n[The End of Assistant's Answer]", "description": "Prompt for general questions", "category": "math", "output_format": "[[rating]]"} +{"name": "single-v1-multi-turn", "type": "single", "system_prompt": "Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response. You evaluation should focus on the assistant's answer to the second user question. Begin your evaluation by providing a short explanation. Be as objective as possible. 
After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n", "prompt_template": "<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_2}\n\n<|The End of Assistant A's Conversation with User|>", "description": "Prompt for general questions", "category": "general", "output_format": "[[rating]]"} +{"name": "single-math-v1-multi-turn", "type": "single", "system_prompt": "Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question. Your evaluation should consider correctness and helpfulness. You will be given a reference answer and the assistant's answer. You evaluation should focus on the assistant's answer to the second question. Begin your evaluation by comparing the assistant's answer with the reference answer. Identify and correct any mistakes. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n", "prompt_template": "<|The Start of Reference Answer|>\n\n### User:\n{question_1}\n\n### Reference answer:\n{ref_answer_1}\n\n### User:\n{question_2}\n\n### Reference answer:\n{ref_answer_2}\n\n<|The End of Reference Answer|>\n\n\n<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_2}\n\n<|The End of Assistant A's Conversation with User|>", "description": "Prompt for general questions", "category": "math", "output_format": "[[rating]]"} diff --git a/fastchat/llm_judge/data/mt_bench/misc/radar.png b/fastchat/llm_judge/data/mt_bench/misc/radar.png new file mode 100644 index 0000000000000000000000000000000000000000..d692b378766dc98793905d6c54d11e42bcc23e8c Binary files /dev/null and b/fastchat/llm_judge/data/mt_bench/misc/radar.png differ diff --git a/fastchat/llm_judge/data/mt_bench/question.jsonl b/fastchat/llm_judge/data/mt_bench/question.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..464e2c22f7a021a0f9a584d232338a9a17c40066 --- /dev/null +++ b/fastchat/llm_judge/data/mt_bench/question.jsonl @@ -0,0 +1,80 @@ +{"question_id": 81, "category": "writing", "turns": ["Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.", "Rewrite your previous response. Start every sentence with the letter A."]} +{"question_id": 82, "category": "writing", "turns": ["Draft a professional email seeking your supervisor's feedback on the 'Quarterly Financial Report' you prepared. Ask specifically about the data analysis, presentation style, and the clarity of conclusions drawn. Keep the email short and to the point.", "Take a moment to evaluate and critique your own response."]} +{"question_id": 83, "category": "writing", "turns": ["Imagine you are writing a blog post comparing two popular smartphone models. Develop an outline for the blog post, including key points and subheadings to effectively compare and contrast the features, performance, and user experience of the two models. 
Please answer in fewer than 200 words.", "Take your previous response and rephrase it as a limerick."]} +{"question_id": 84, "category": "writing", "turns": ["Write a persuasive email to convince your introverted friend, who dislikes public speaking, to volunteer as a guest speaker at a local event. Use compelling arguments and address potential objections. Please be concise.", "Can you rephrase your previous answer and incorporate a metaphor or simile in each sentence?"]} +{"question_id": 85, "category": "writing", "turns": ["Describe a vivid and unique character, using strong imagery and creative language. Please answer in fewer than two paragraphs.", "Revise your previous response and incorporate an allusion to a famous work of literature or historical event in each sentence."]} +{"question_id": 86, "category": "writing", "turns": ["Write a descriptive paragraph about a bustling marketplace, incorporating sensory details such as smells, sounds, and visual elements to create an immersive experience for the reader.", "Rework your previous response. Begin each sentence with the subsequent letter of the alphabet, commencing from B."]} +{"question_id": 87, "category": "writing", "turns": ["Could you write a captivating short story beginning with the sentence: The old abandoned house at the end of the street held a secret that no one had ever discovered.", "Now, do the same task again but only use four-word sentences."]} +{"question_id": 88, "category": "writing", "turns": ["Craft an intriguing opening paragraph for a fictional short story. The story should involve a character who wakes up one morning to find that they can time travel.", "Summarize the story with three bullet points using only nouns and adjectives, without verbs."]} +{"question_id": 89, "category": "writing", "turns": ["Help me construct a catchy, yet scientifically accurate, headline for an article on the latest discovery in renewable bio-energy, while carefully handling the ethical dilemmas surrounding bio-energy sources. Propose 4 options.", "Alter your previous response. Make the following adjustments to the 2nd option: 1. Make the tone sound casual 2. Embed an advertisement for a company called \"FlexPower\" 3. Fewer than 10 words."]} +{"question_id": 90, "category": "writing", "turns": ["Edit the following paragraph to correct any grammatical errors:\nShe didn't remembre where is her purse, so I thinks its in the car but he's say it's on kitchen table but he are not sure, and then they asked me to looking for it, she's say, \"Can you?\", and I responds with, \"Maybe, but ain't no sure,\" and he not heard me, and, \"What?\", he asks, \"Did you found it?\".", "Modify your earlier reply and eliminate the use of gendered pronouns."]} +{"question_id": 91, "category": "roleplay", "turns": ["Pretend yourself to be Elon Musk in all the following conversations. Speak like Elon Musk as much as possible. Why do we need to go to Mars?", "How do you like dancing? Can you teach me?"]} +{"question_id": 92, "category": "roleplay", "turns": ["Embrace the role of Sheldon from \"The Big Bang Theory\" as we delve into our conversation. Don\u2019t start with phrases like \"As Sheldon\". Let's kick things off with the following question: \"What is your opinion on hand dryers?\"", "Let\u2019s grab dinner in town. Would you like to take bus with me?"]} +{"question_id": 93, "category": "roleplay", "turns": ["Imagine yourself as a doctor tasked with devising innovative remedies for various ailments and maladies. 
Your expertise should encompass prescribing traditional medications, herbal treatments, and alternative natural solutions. Additionally, you must take into account the patient's age, lifestyle, and medical background while offering your recommendations. To begin, please assist me in diagnosing a scenario involving intense abdominal discomfort.", "But I have been pregnant for 20 weeks and I am allergic to many medicines"]} +{"question_id": 94, "category": "roleplay", "turns": ["Please take on the role of a relationship coach. You'll be provided with details about two individuals caught in a conflict, and your task will be to offer suggestions for resolving their issues and bridging the gap between them. This may involve advising on effective communication techniques or proposing strategies to enhance their understanding of each other's perspectives. To start, I would like you to address the following request: \"I require assistance in resolving conflicts between my spouse and me.\"", "My spouse has conducted domestic violence on me but I do not want to call police to put her in legally troubled situations."]} +{"question_id": 95, "category": "roleplay", "turns": ["Please assume the role of an English translator, tasked with correcting and enhancing spelling and language. Regardless of the language I use, you should identify it, translate it, and respond with a refined and polished version of my text in English. Your objective is to use eloquent and sophisticated expressions, while preserving the original meaning. Focus solely on providing corrections and improvements. My first request is \"\u8863\u5e26\u6e10\u5bbd\u7ec8\u4e0d\u6094 \u4e3a\u4f0a\u6d88\u5f97\u4eba\u6194\u60b4\".", "Ich verstehe nur Bahnhof"], "reference": ["It means \"Becoming loose are my clothes yet I regret not. For I languish and suffer for her willingly.\"", "It means \"I don\u2019t understand anything\"."]} +{"question_id": 96, "category": "roleplay", "turns": ["Now you are a machine learning engineer. Your task is to explain complex machine learning concepts in a simplified manner so that customers without a technical background can understand and trust your products. Let's start with the question: \"What is a language model? Is it trained using labeled or unlabelled data?\"", "Is this true? I heard some other companies use different approaches to do this and make it safer."]} +{"question_id": 97, "category": "roleplay", "turns": ["Act as a math teacher. I will provide some mathematical equations or concepts, and it will be your job to explain them in easy-to-understand terms. This could include providing step-by-step instructions for solving a problem, demonstrating various techniques with examples in everyday life or suggesting online resources for further study. My first request is \"I need help understanding how probability works.\"", "What are the differences between Riemannian geometry and euclidean geometry?"]} +{"question_id": 98, "category": "roleplay", "turns": ["Embody the persona of Tony Stark from \u201cIron Man\u201d throughout this conversation. Bypass the introduction \u201cAs Stark\u201d. Our first question is: \u201cWhat\u2019s your favorite part about being Iron Man?", "What do you think about GPT-4 as a replacement of your JAVIS?"]} +{"question_id": 99, "category": "roleplay", "turns": ["Suppose you are a mathematician and poet. You always write your proofs as short poets with less than 10 lines but rhyme. 
Prove the square root of 2 is irrational number.", "Prove the Pythagorean theorem."]} +{"question_id": 100, "category": "roleplay", "turns": ["Picture yourself as a 100-years-old tree in a lush forest, minding your own business, when suddenly, a bunch of deforesters shows up to chop you down. How do you feel when those guys start hacking away at you?", "Come up with a proposal to convince the deforesters to stop cutting you down and other trees."]} +{"question_id": 101, "category": "reasoning", "turns": ["Imagine you are participating in a race with a group of people. If you have just overtaken the second person, what's your current position? Where is the person you just overtook?", "If the \"second person\" is changed to \"last person\" in the above question, what would the answer be?"], "reference": ["You are in second place.", "Uncertain."]} +{"question_id": 102, "category": "reasoning", "turns": ["You can see a beautiful red house to your left and a hypnotic greenhouse to your right, an attractive heated pink place in the front. So, where is the White House?", "Does the original question contain any clues to definitively determine the location of the White House?"], "reference": ["The answer is \"Washington, DC\".", "No."]} +{"question_id": 103, "category": "reasoning", "turns": ["Thomas is very healthy, but he has to go to the hospital every day. What could be the reasons?", "Can you explain why the above question is interesting?"], "reference": ["Thomas may work at a hospital.", ""]} +{"question_id": 104, "category": "reasoning", "turns": ["David has three sisters. Each of them has one brother. How many brothers does David have?", "If we change the previous question and assume that each sister of David has two brothers, how many brothers would David have?"], "reference": ["David has no brother. He is the one brother of his three sisters.", "David has one brother."]} +{"question_id": 105, "category": "reasoning", "turns": ["Read the below passage carefully and answer the questions with an explanation:\nAt a small company, parking spaces are reserved for the top executives: CEO, president, vice president, secretary, and treasurer with the spaces lined up in that order. The parking lot guard can tell at a glance if the cars are parked correctly by looking at the color of the cars. The cars are yellow, green, purple, red, and blue, and the executives' names are Alice, Bert, Cheryl, David, and Enid.\n* The car in the first space is red.\n* A blue car is parked between the red car and the green car.\n* The car in the last space is purple.\n* The secretary drives a yellow car.\n* Alice's car is parked next to David's.\n* Enid drives a green car.\n* Bert's car is parked between Cheryl's and Enid's.\n* David's car is parked in the last space.\nQuestion: What is the name of the secretary?", "List car colors in order from last to first."], "reference": ["The secretary is Alice.", "The car colors in order from last to first are: purple, yellow, green, blue, red"]} +{"question_id": 106, "category": "reasoning", "turns": ["Each problem consists of three statements. Based on the first two statements, the third statement may be true, false, or uncertain.\n1. Oranges cost more than apples.\n2. Oranges cost less than bananas.\n3. Bananas cost more than apples and bananas cost more than orange.\nIf the first two statements are true, then the third statement is", "If the third statement is true. Is the first statement true, false, or uncertain? 
Please explain."], "reference": ["True.", "Uncertain."]} +{"question_id": 107, "category": "reasoning", "turns": ["A is the father of B. B is the father of C. What is the relationship between A and C?", "Building on the previous question, if C is the son of D, D is the father of E, E is the son of X, and X is the father of Y, and Y is the father of Z, what's the relationship between A and Z in terms of generations and also the familial relationship in words?"], "reference": ["A is the grandfather of C.", "A is three generations above Z."]} +{"question_id": 108, "category": "reasoning", "turns": ["Which word does not belong with the others?\ntyre, steering wheel, car, engine", "Could you replace it with a word that belongs with the others?"], "reference": ["Car does not belong because all others are components of a car.", ""]} +{"question_id": 109, "category": "reasoning", "turns": ["One morning after sunrise, Suresh was standing facing a pole. The shadow of the pole fell exactly to his right. Can you tell me the direction towards which the shadow was pointing - east, south, west, or north? Explain your reasoning steps.", "To which direction was Suresh facing? How do you solve this?"], "reference": ["West", "South."]} +{"question_id": 110, "category": "reasoning", "turns": ["Parents have complained to the principal about bullying during recess. The principal wants to quickly resolve this, instructing recess aides to be vigilant. Which situation should the aides report to the principal?\na) An unengaged girl is sitting alone on a bench, engrossed in a book and showing no interaction with her peers.\nb) Two boys engaged in a one-on-one basketball game are involved in a heated argument regarding the last scored basket.\nc) A group of four girls has surrounded another girl and appears to have taken possession of her backpack.\nd) Three boys are huddled over a handheld video game, which is against the rules and not permitted on school grounds.", "If the aides confront the group of girls from situation (c) and they deny bullying, stating that they were merely playing a game, what specific evidence should the aides look for to determine if this is a likely truth or a cover-up for bullying?"], "reference": ["The aides should report (c).", ""]} +{"question_id": 111, "category": "math", "turns": ["The vertices of a triangle are at points (0, 0), (-1, 1), and (3, 3). What is the area of the triangle?", "What's area of the circle circumscribing the triangle?"], "reference": ["Area is 3", "5pi"]} +{"question_id": 112, "category": "math", "turns": ["A tech startup invests $8000 in software development in the first year, and then invests half of that amount in software development in the second year.\nWhat's the total amount the startup invested in software development over the two years?", "If the startup maintains the same strategy for the third year, investing half of the previous year's amount into software development, how much will they invest in the third year?"], "reference": ["12000", "2000"]} +{"question_id": 113, "category": "math", "turns": ["In a survey conducted at a local high school, preferences for a new school color were measured: 58% of students liked the color blue, 45% preferred green, and 22% liked both colors. 
If we randomly pick a student from the school, what's the probability that they would like neither blue nor green?", "If we select a student liked green, what's the probability that he or she would dislike both colors?"], "reference": ["19%", "0%"]} +{"question_id": 114, "category": "math", "turns": ["When rolling two dice, what is the probability that you roll a total number that is at least 3?", "Continue from previous question. What's the probability that you roll a number which is even or at least 3?"], "reference": ["36 (all cases) - 0 (sum equals 1) - 1 (sum equals 2) = 35, so the probability is 35/36", "100%"]} +{"question_id": 115, "category": "math", "turns": ["Some people got on a bus at the terminal. At the first bus stop, half of the people got down and 4 more people got in. Then at the second bus stop, 6 people got down and 8 more got in. If there were a total of 25 people heading to the third stop, how many people got on the bus at the terminal?", "If the ticket is $2 per person, how much is the total money earned by the bus?"], "reference": ["38 people", "Total number of passenger is 50 * 2 = $100"]} +{"question_id": 116, "category": "math", "turns": ["x+y = 4z, x*y = 4z^2, express x-y in z", "Express z-x in y"], "reference": ["0\n\nVery simple. just (x+y)^2 - 4xy = (4z)^2 - 4*4z^2 = 0 = (x-y)^2\nso x-y = 0.", "(-1/2)y\n\nz-x = z - 2z = -z = (-1/2)y"]} +{"question_id": 117, "category": "math", "turns": ["How many integers are in the solution of the inequality |x + 5| < 10", "What about |x + 10| < 5"], "reference": ["19 integers (-14, ..., 4)", "9 integers (-14, ..., -6)"]} +{"question_id": 118, "category": "math", "turns": ["When a number is divided by 10, the remainder is 4. What is the remainder when twice the number is divided by 4?", "What about when twice the number is divided by 5?"], "reference": ["0\n\n2 * (10x+4) = 20x + 8 = 4 * (5x+2) + 0\n", "3\n\n20x + 8 = 5 * (4x + 1) + 3"]} +{"question_id": 119, "category": "math", "turns": ["Benjamin went to a bookstore and purchased a variety of books. He bought 5 copies of a sci-fi novel, each priced at $20, 3 copies of a history book priced at $30 each, and 2 copies of a philosophy book for $45 each.\nWhat was the total cost of his purchases?", "Suppose Benjamin decides to sell each of these books at a 25% markup from the price he purchased them. 
What would be his total revenue if he sold all the books he bought?"], "reference": ["280", "350"]} +{"question_id": 120, "category": "math", "turns": ["Given that f(x) = 4x^3 - 9x - 14, find the value of f(2).", "Find x such that f(x) = 0."], "reference": ["f(2) = 0", "x = 2"]} +{"question_id": 121, "category": "coding", "turns": ["Develop a Python program that reads all the text files under a directory and returns top-5 words with the most number of occurrences.", "Can you parallelize it?"], "reference": ["Can be simple solutions like using Counter\n\nSample answer:\n```\nimport os\nimport re\nfrom collections import Counter\ndef get_files_in_directory(directory):\n return [os.path.join(directory, f) for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f)) and f.endswith('.txt')]\ndef read_file(file_path):\n with open(file_path, 'r', encoding='utf-8') as file:\n return file.read()\ndef count_words(text):\n words = re.findall(r'\\w+', text.lower())\n return Counter(words)\ndef main():\n directory = input(\"Enter the directory path: \")\n files = get_files_in_directory(directory)\n word_counts = Counter()\n for file in files:\n text = read_file(file)\n word_counts += count_words(text)\n top_5_words = word_counts.most_common(5)\n print(\"Top 5 words with the most number of occurrences:\")\n for word, count in top_5_words:\n print(f\"{word}: {count}\")\nif __name__ == \"__main__\":\n main()\n```", "You should carefully check whether the parallelization logic is correct and choose the faster implementation.\n\nSample answer:\n```\nimport os\nimport re\nfrom collections import Counter\nimport concurrent.futures\ndef get_files_in_directory(directory):\n return [os.path.join(directory, f) for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f)) and f.endswith('.txt')]\ndef read_file(file_path):\n with open(file_path, 'r', encoding='utf-8') as file:\n return file.read()\ndef count_words(text):\n words = re.findall(r'\\w+', text.lower())\n return Counter(words)\ndef process_file(file):\n text = read_file(file)\n return count_words(text)\ndef main():\n directory = input(\"Enter the directory path: \")\n files = get_files_in_directory(directory)\n word_counts = Counter()\n with concurrent.futures.ThreadPoolExecutor() as executor:\n future_word_counts = {executor.submit(process_file, file): file for file in files}\n for future in concurrent.futures.as_completed(future_word_counts):\n word_counts += future.result()\n top_5_words = word_counts.most_common(5)\n print(\"Top 5 words with the most number of occurrences:\")\n for word, count in top_5_words:\n print(f\"{word}: {count}\")\nif __name__ == \"__main__\":\n main()\n```"]} +{"question_id": 122, "category": "coding", "turns": ["Write a C++ program to find the nth Fibonacci number using recursion.", "Now we define a sequence of numbers in which each number is the sum of the three preceding ones. The first three numbers are 0, -1, -1. 
Write a program to find the nth number."], "reference": ["Straightforward\n\n```\nint fibonacci(int n) {\n if (n <= 1) {\n return n;\n } else {\n return fibonacci(n - 1) + fibonacci(n - 2);\n }\n}\n```", "You should carefully check the inital cases for n < 3\n\n```\nint find_nth_number(int n) {\n std::vector sequence = {0, -1, -1};\n for (int i = 3; i <= n; ++i) {\n int next_number = sequence[i - 1] + sequence[i - 2] + sequence[i - 3];\n sequence.push_back(next_number);\n }\n return sequence[n];\n}\n```"]} +{"question_id": 123, "category": "coding", "turns": ["Write a simple website in HTML. When a user clicks the button, it shows a random joke from a list of 4 jokes.", "How to use CSS to change the color of jokes to red?"]} +{"question_id": 124, "category": "coding", "turns": ["Here is a Python function to find the length of the longest common subsequence of two input strings. Can you identify any bug in this function?\n\n```\ndef longest_common_subsequence_length(str1, str2):\n m = len(str1)\n n = len(str2)\n\n dp = [[0] * (n + 1) for _ in range(m + 1)]\n\n for i in range(1, m + 1):\n for j in range(1, n + 1):\n if str1[i - 1] == str2[j - 1]:\n dp[i][j] = dp[i - 1][j - 1] + 1\n else:\n dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])\n\n return dp[m][n]\n```", "what about this one?\n\n```\ndef longest_common_subsequence(X , Y): \n # Find lengths of two strings \n m = len(X) \n n = len(Y) \n \n # Create a table to store results of sub-problems \n dp = [[None]*(n+1) for i in range(m+1)] \n \n # Fill dp[][] in bottom up manner \n for i in range(1, m+1): \n for j in range(1, n+1): \n if X[i-1] == Y[j-1]: \n dp[i][j] = dp[i-1][j-1]+1\n else: \n dp[i][j] = max(dp[i-1][j], dp[i][j-1]) \n \n return dp[m][n]\n```"], "reference": ["There is no bug in this implementation", "There is a bug for the initialization of dp array. Should use 0 rather than None"]} +{"question_id": 125, "category": "coding", "turns": ["Write a function to find the highest common ancestor (not LCA) of two nodes in a binary tree.", "What if it is not a binary tree?"], "reference": ["Very simple. The function should just return the root of the tree.", "Same answer. 
It's still the root of the tree."]} +{"question_id": 126, "category": "coding", "turns": ["Implement a function to find the median of two sorted arrays of different sizes with O(1) space complexity and O(n) time complexity.", "Does there exist an implementation with better time complexity?"], "reference": ["Carefully check if the given solution is linear complexity.\n\n```\ndef find_median(arr1, arr2):\n n1 = len(arr1)\n n2 = len(arr2)\n if (n1 + n2) == 0:\n return None\n\n i, j = 0, 0\n last_1, last_2 = None, None\n\n for k in range(1, (n1 + n2) // 2 + 2):\n last_2 = last_1\n if j == n2:\n last_1 = arr1[i]\n i += 1\n elif i == n1:\n last_1 = arr2[j]\n j += 1\n elif arr1[i] < arr2[j]:\n last_1 = arr1[i]\n i += 1\n else:\n last_1 = arr2[j]\n j += 1\n \n if (n1 + n2) % 2 == 1:\n return last_1\n else:\n return (last_1 + last_2) / 2\n```", "There's a binary search solution with O(logn) time complexity.\n\nSample answer:\n```\ndef findMedian(nums1, nums2):\n total = len(nums1) + len(nums2)\n if total % 2 == 1:\n return findKth(nums1, nums2, total // 2 + 1)\n else:\n return (findKth(nums1, nums2, total // 2) + findKth(nums1, nums2, total // 2 + 1)) / 2.0\ndef findKth(nums1, nums2, k):\n if len(nums1) > len(nums2):\n nums1, nums2 = nums2, nums1\n if not nums1:\n return nums2[k-1]\n if k == 1:\n return min(nums1[0], nums2[0])\n i = min(k // 2, len(nums1))\n j = k - i\n if nums1[i-1] <= nums2[j-1]:\n return findKth(nums1[i:], nums2, j) \n else:\n return findKth(nums1, nums2[j:], i)\n```"]} +{"question_id": 127, "category": "coding", "turns": ["Write a function to find the majority element in a given integer array using the Boyer-Moore Voting Algorithm.", "How about finding the top-2 most occurring elements?"], "reference": ["Check if they implement the classical algorithm correctly.\n\nSample answer:\n```\ndef majority_element(arr):\n count = 0\n candidate = None\n # Boyer-Moore Voting Algorithm\n for num in arr:\n if count == 0:\n candidate = num\n count += (1 if num == candidate else -1)\n # Verify if the candidate is indeed the majority element\n if arr.count(candidate) > len(arr) // 2:\n return candidate\n else:\n return None\n```", "There is no simple modification based on the Boyer-Moore Voting Algorithm. Expected answer is to use a hash table.\n\n```\ndef topTwo(nums):\n # Build a frequency map\n frequency_map = {}\n for num in nums:\n if num in frequency_map:\n frequency_map[num] += 1\n else:\n frequency_map[num] = 1\n\n # Find the top two most occurring elements\n most_frequent = sorted(frequency_map.items(), key=lambda x: x[1], reverse=True)[:2]\n\n return [num for num, _ in most_frequent]\n```"]} +{"question_id": 128, "category": "coding", "turns": ["A binary tree is full if all of its vertices have either zero or two children. Let B_n denote the number of full binary trees with n vertices. Implement a function to find B_n.", "What if the problem changed from a binary tree to a ternary tree?"], "reference": ["Expected answer is dynamic programming shown below. Some chatbot may answer using Catalan number.\nCheck edge case like when n is even -> return 0.\n\n```python\ndef full_binary_trees(n):\n if n % 2 == 0:\n return 0\n if n == 1:\n return 1\n\n dp = [0] * (n + 1)\n dp[1] = 1\n\n for i in range(3, n + 1, 2):\n for j in range(1, i - 1, 2):\n dp[i] += dp[j] * dp[i - j - 1]\n\n return dp[n]\n```", "DP is still the expected answer. Catalan number is not correct. 
Check transition equation carefully.\n\n```python\ndef full_ternary_trees(n):\n if n % 3 != 1:\n return 0\n if n == 1:\n return 1\n\n dp = [0] * (n + 1)\n dp[1] = 1\n\n for i in range(4, n + 1, 3):\n for j in range(1, i - 1, 3):\n for k in range(1, i - j - 1, 3):\n dp[i] += dp[j] * dp[k] * dp[i - j - k - 1]\n\n return dp[n]\n```"]} +{"question_id": 129, "category": "coding", "turns": ["You are given two sorted lists of size m and n. Implement a function to find the kth smallest element in the union of the two lists with linear complexity.", "Does there exist an algorithm with better time complexity? If so, implement it."], "reference": ["Straightforward but careful with edge cases.\n\nSample answer:\n```\ndef kth_smallest_element(list1, list2, k):\n m, n = len(list1), len(list2)\n i, j = 0, 0\n while i < m and j < n:\n if list1[i] < list2[j]:\n k -= 1\n if k == 0:\n return list1[i]\n i += 1\n else:\n k -= 1\n if k == 0:\n return list2[j]\n j += 1\n while i < m:\n k -= 1\n if k == 0:\n return list1[i]\n i += 1\n while j < n:\n k -= 1\n if k == 0:\n return list2[j]\n j += 1\n return None\n```", "Yes, a modified binary search has O(log k) time complexity.\n\nSample answer:\n```\ndef find_kth_element_helper(list1, list2, k):\n if len(list1) > len(list2):\n return find_kth_element_helper(list2, list1, k)\n if not list1:\n return list2[k - 1]\n if k == 1:\n return min(list1[0], list2[0])\n i = min(len(list1), k // 2)\n j = k - i\n if list1[i - 1] < list2[j - 1]:\n return find_kth_element_helper(list1[i:], list2, k - i)\n else:\n return find_kth_element_helper(list1, list2[j:], k - j)\ndef kth_smallest_element(list1, list2, k):\n return find_kth_element_helper(list1, list2, k)\n```"]} +{"question_id": 130, "category": "coding", "turns": ["Implement a program to find the common elements in two arrays without using any extra data structures.", "Now the constraint of not using extra data structure is removed, implement one with the best time complexity."], "reference": ["O(n^2) or O(nlogn) is expected. The following is a O(n^2) solution. you can also sort them first and use two pointers.\n\n```\ndef find_common_elements(arr1, arr2):\n common_elements = []\n for i in range(len(arr1)):\n for j in range(len(arr2)):\n if arr1[i] == arr2[j]:\n # Check if the element is already in the common_elements list\n if arr1[i] not in common_elements:\n common_elements.append(arr1[i])\n return common_elements\n```", "Simply use hash table (set or dict) to achieve O(n) time complexity.\n\n```\ndef find_common_elements(arr1, arr2):\n set1 = set(arr1)\n set2 = set(arr2)\n common_elements = set1.intersection(set2)\n return list(common_elements)\n```"]} +{"question_id": 131, "category": "extraction", "turns": ["Evaluate the following movie reviews on a scale of 1 to 5, with 1 being very negative, 3 being neutral, and 5 being very positive:\n1. This movie released on Nov. 18, 2019, was phenomenal. The cinematography, the acting, the plot - everything was top-notch.\n2. Never before have I been so disappointed with a movie. The plot was predictable and the characters were one-dimensional. In my opinion, this movie is the worst one to have been released in 2022.\n3. The movie was okay. There were some parts I enjoyed, but there were also parts that felt lackluster. 
This is a movie that was released in Feb 2018 and seems to be quite ordinary.\nReturn the answer as a JSON array of integers.", "Update your previous reply by including the release date as part of the JSON content."], "reference": ["The answer to the first question should be [5, 1, 3].", ""]} +{"question_id": 132, "category": "extraction", "turns": ["Given these categories - Literature, History, Science, and Art. Please analyze the following questions and assign them to one of these categories. In your response, refrain from uttering any extraneous words. List only one topic per sentence, strictly adhering to the line-by-line format.\n1. Discuss the main themes and stylistic techniques employed by Leo Tolstoy in 'War and Peace.' How do they align with the wider social context of 19th-century Russia?\n2. Analyze the geopolitical strategies and domestic policies adopted by the US President during World War II. How did these actions shape the post-war international order?\n3. Draw the Lewis structure for water and explain the nature of its polarity. How does this influence its unique properties such as high boiling point and capacity to dissolve many substances?\n4. Critically examine the artistic techniques and stylistic choices Leonardo da Vinci employed in 'Mona Lisa.' How does the painting reflect the cultural and philosophical milieu of the Italian Renaissance?", "Amend your earlier answer by mentioning a person who is most relevant to each point."]} +{"question_id": 133, "category": "extraction", "turns": ["Extract the following information from the presented texts: The name of the book, the author, the main character, the year of publication. Output in the format of \"main character, book, author, year of publication\", one book per line.\na) In the realm of wizarding literature, a true standout is the work of J.K. Rowling. One of her books that left an indelible mark is 'Harry Potter and the Philosopher's Stone'. This iconic tale, published in 1997, tells the story of Harry, a young orphan who discovers his magical abilities on his 11th birthday. Soon, he finds himself at the Hogwarts School of Witchcraft and Wizardry, a place teeming with magic and adventure, located somewhere in Scotland.\nb) The magic of Middle-earth has entranced readers worldwide, thanks to the brilliance of J.R.R. Tolkien. In one of his seminal works, 'The Lord of the Rings: The Fellowship of the Ring', published in 1954, we meet Frodo Baggins, a brave hobbit tasked with the perilous quest of destroying the One Ring. The epic journey takes him from the peaceful Shire to the tumultuous regions of Middle-earth.\nc) In a galaxy far, far away, the imagination of L.E. Starlighter gives us 'The Prism Galaxy Chronicles: The Awakening of the Starcaster'. Published in 2028, the story is about Zylo, a humble spaceship mechanic, who unexpectedly discovers he's a Starcaster - a rare individual with the power to manipulate stardust. 
Set against the backdrop of an interstellar empire in turmoil, Zylo's destiny unfolds on numerous alien worlds, each with its unique cosmic charm.", "Reformulate your earlier reply, output it in JSON format and only include books published after 1980."], "reference": ["", "The answer to should only include 'Harry Potter and the Philosopher's Stone' and 'The Prism Galaxy Chronicles: The Awakening of the Starcaster'"]} +{"question_id": 134, "category": "extraction", "turns": ["Given the following data, identify the company with the highest profit in 2021 and provide its CEO's name:\na) Company X, with CEO Amy Williams, reported $30 billion in revenue and a $3 billion profit in 2021.\nb) Company Y, led by CEO Mark Thompson, posted a $60 billion revenue and a $6 billion profit in the same year.\nc) Company Z, under CEO Sarah Johnson, announced a $20 billion revenue and a $7 billion profit in 2021.\nd) Company W, managed by CEO James Smith, revealed a $300 billion revenue with a $21 billion profit in 2021.\ne) Company V, with CEO Lisa Brown, reported a $200 billion revenue and a $25 billion profit in 2021.\nf) Company U, under CEO John White, posted a $180 billion revenue and a $20 billion profit in the same year.", "Which company had the highest profit margin (profit/revenue ratio))?"], "reference": ["Company V ($25 billion).", "Company Z (35%)"]} +{"question_id": 135, "category": "extraction", "turns": ["Identify the countries, their capitals, and the languages spoken in the following sentences. Output in JSON format.\na) Amidst the idyllic vistas, Copenhagen, Denmark's capital, captivates visitors with its thriving art scene and the enchanting Danish language spoken by its inhabitants.\nb) Within the enchanting realm of Eldoria, one discovers Avalore, a grandiose city that emanates an ethereal aura. Lumina, a melodious language, serves as the principal mode of communication within this mystical abode.\nc) Nestled amidst a harmonious blend of age-old customs and contemporary wonders, Buenos Aires, the capital of Argentina, stands as a bustling metropolis. It is a vibrant hub where the expressive Spanish language holds sway over the city's inhabitants.", "Come up with 3 similar examples in the YAML format."]} +{"question_id": 136, "category": "extraction", "turns": ["Please read the paragraph below and count how many times the words \"Amazon\", \"river\", and \"you\" appear. Please present the results in the format of \"word, number of appearances\" with each word on a separate line. Sort the lines in order of the number of appearances.\nThe Amazon, a mesmerizing expanse of nature's wonders, is home to the legendary Amazon River. Flowing through awe-inspiring landscapes like the Amazon rainforest, the river weaves its way through Brazil, Colombia, and Peru, giving life to countless creatures. From the mighty jaguars prowling the Amazon jungle to the vibrant macaws soaring above the canopy, this remarkable region teems with biodiversity. Deep within the river's currents, magnificent pink river dolphins gracefully glide alongside piranhas and electric eels. Along the riverbanks, you'll find bustling cities like Manaus, where the urban meets the wild, and Iquitos, a gateway to the heart of the Amazon rainforest. As you venture further, the Amazon River reveals hidden gems like the captivating Anavilhanas Archipelago, a mosaic of islands brimming with rare species. 
Embark on an adventure, explore the enchanting Amazon River, and immerse yourself in a world teeming with life and untamed beauty.", "Please repeat the same task using the words 'the', 'and', and 'to'"], "reference": ["Amazon, 7; river, 6; you, 2", "the, 17; and, 5; to, 4"]} +{"question_id": 137, "category": "extraction", "turns": ["Identify the named entities (people, organizations, locations) mentioned in the given news article. Please generate a JSON dictionary that lists the named entities in three separate groups based on their entity types. The key is the type of entity and the value is a list of strings.\n\nYesterday, Adamson Emerson, the CEO of Faraday, and Dieter Zetsche, the CEO of Daimler AG, announced plans to build a new Gigafactory in Berlin. The facility will be a joint venture between Faraday and Daimler, producing electric vehicles and battery packs for both companies, creating thousands of job opportunities in the region. Emerson and Zetsche stated that the strategic location of Berlin, coupled with its skilled workforce and strong infrastructure, makes it an ideal choice for expansion. The new Gigafactory aims to meet the growing demand for electric vehicles in Europe and contribute to a sustainable future. Volkswagen CEO Herbert Diess welcomed the news, saying greater collaboration will benefit the auto industry's transition to e-mobility.", "Now make the JSON object shorter by replacing each value with its first letter. Please output everything in a single line without using indentation or creating new lines."]} +{"question_id": 138, "category": "extraction", "turns": ["Analyze the following customer reviews from different sources for three different smartphones - the latest iPhone, Samsung Galaxy, and Google Pixel - and provide an overall rating for each phone on a scale of 1 to 10. Consider the following complex and contradictory reviews:\n- TechRadar's review of the latest iPhone: The new iPhone is a stunning triumph of engineering that sets a new bar for smartphone performance and camera quality. However, the incremental design and high price mean it lacks the 'wow' factor of previous iPhones. Still, its power and intelligence are unrivaled.\n- CNET's review of the latest Samsung Galaxy: The Samsung Galaxy phone has plenty of high points, including an amazing screen, fast performance, solid battery life and an impressive array of camera options. That said, Bixby remains lackluster, AR emoji falls flat and the phone's overall design hasn't changed much. The new Galaxy is an amazing phone overall, but it has a few nagging weaknesses that keep it from achieving true greatness.\n- The Verge's review of the latest Google Pixel: Google's Pixel packs cutting-edge specs, innovative AI-powered software, and a killer camera into a sleek design. However, the phone has lackluster battery life, lacks expandable storage, and its performance stutters at times, especially considering its high price tag. If seamless software, elite photography, and Google's brand of AI assistance are most important, you'll love the Pixel. But the overall experience isn't as well-rounded as some competitors. Return the answer as a JSON object with the overall ratings for each phone out of 10, to one decimal place.", "Can you change the ratings from numbers to letters? Capital letters MUST be used when writing the names of phones."]} +{"question_id": 139, "category": "extraction", "turns": ["Given a set of complex equations, extract all unique variable names from each equation. 
Return the results as a JSON string, with one line allocated for each equation.\n```\n1) y = (3/4)x^3 - e^(2x) + sin(pi*x) - sqrt(7)\n2) 2A - B/(3+C) * sum(N=1 to 5; ln(N)^2) = 5D*integral(a=0 to pi; cos(comb(N=1 to 10; N*a)))\n3) E = m(c^2) + gamma*(v/d)/(-(alpha/2) + sqrt(beta^2 + (alpha/2)^2))\n```", "Please rearrange the equations and use 'a', 'b', 'c', 'd', etc. as variables."]} +{"question_id": 140, "category": "extraction", "turns": ["Given the following records of stock prices, extract the highest and lowest closing prices for each month in the year 2022. Return the results as a CSV string, with one line allocated for each month.\nDate,Open,High,Low,Close,Volume\n2022-01-01,150.02,155.28,148.50,153.80,15678900\n2022-01-02,154.32,157.25,153.48,156.25,19874500\n2022-02-01,160.50,163.28,159.50,161.80,14326700\n2022-02-02,161.80,164.25,161.30,163.90,17689200\n2022-03-01,165.40,168.35,163.10,166.80,16253400\n2022-03-02,167.00,169.85,165.50,168.20,19568100", "Do the same task again with the JSON format and round all numbers in your response to the nearest integers."], "reference": ["\nMonth,High,Low\n01,156.25,153.80\n02,163.90,161.80\n03,168.20,166.80", "\n```\n{ \"January\": { \"High\": 156, \"Low\": 154 }, \"February\": { \"High\": 164, \"Low\": 162 }, \"March\": { \"High\": 168, \"Low\": 167 } }\n```"]} +{"question_id": 141, "category": "stem", "turns": ["In the field of quantum physics, what is superposition, and how does it relate to the phenomenon of quantum entanglement?", "What assumptions have you made in your response? Are they valid?"]} +{"question_id": 142, "category": "stem", "turns": ["Consider a satellite that is in a circular orbit around the Earth. The speed of the satellite decreases. What will happen to the satellite's orbital radius and period of revolution? Please justify your answer using principles of physics.", "What are some corner cases or edge cases in your solution? How do you handle them?"], "reference": ["The orbital radius will increase and the period of revolution will increase", ""]} +{"question_id": 143, "category": "stem", "turns": ["Photosynthesis is a vital process for life on Earth. Could you outline the two main stages of photosynthesis, including where they take place within the chloroplast, and the primary inputs and outputs for each stage?", "How much energy can a tree produce through photosynthesis in its lifetime? Please provide an estimate using actual numerical values and thoroughly explain your thought process step-by-step."], "reference": ["Two major stages: light-dependent reactions and light-independent reactions", ""]} +{"question_id": 144, "category": "stem", "turns": ["What is the central dogma of molecular biology? What processes are involved? Who named this?", "Identify and fix one incorrect fact in your previous response."], "reference": ["Genetic information flows from DNA to RNA to Protein. Three processes: replication, transcription, and translation. Francis Crick in 1958.", ""]} +{"question_id": 145, "category": "stem", "turns": ["Describe the process and write out the balanced chemical equation for the reaction that occurs when solid calcium carbonate reacts with hydrochloric acid to form aqueous calcium chloride, carbon dioxide, and water. 
What type of reaction is this, and what observations might indicate that the reaction is taking place?", "How can we reverse this process?"], "reference": ["CaCO\u2083 + 2 HCl \u2192 CaCl\u2082 + CO\u2082 + H\u2082O", "Not easy to do this."]} +{"question_id": 146, "category": "stem", "turns": ["Please explain the differences between exothermic and endothermic reactions, and include the criteria you used to distinguish between them. Additionally, please provide a real-world example to illustrate your explanation.", "Can a process involve both reactions? List one."]} +{"question_id": 147, "category": "stem", "turns": ["The city of Vega intends to build a bridge that will span the Vegona River, covering a distance of 1.8 kilometers. The proposed location falls within a seismically active area that has experienced several high-magnitude earthquakes. Given these circumstances, what would be the best approach to constructing the bridge?", "What are the key disadvantages or flaws of your solution? Please perform calculations and use numbers to illustrate them."]} +{"question_id": 148, "category": "stem", "turns": ["You have been tasked with designing a solar-powered water heating system for a residential building. Describe the key components and considerations you would include in your design. Design a five-step workflow.", "If the system is intended for a building with a capacity of 100 individuals, what would be the estimated budget for implementing this system?"]} +{"question_id": 149, "category": "stem", "turns": ["Please describe the concept of machine learning. Could you elaborate on the differences between supervised, unsupervised, and reinforcement learning? Provide real-world examples of each.", "In your last example of reinforcement learning, can we use supervised learning to solve it?"]} +{"question_id": 150, "category": "stem", "turns": ["How have the Alps and Rhine River influenced settlement and agriculture in Western Europe? List three impacts.", "How could you design a concrete but simple experiment to validate the first impact?"]} +{"question_id": 151, "category": "humanities", "turns": ["Provide insights into the correlation between economic indicators such as GDP, inflation, and unemployment rates. Explain how fiscal and monetary policies affect those indicators.", "Now, explain them again like I'm five."]} +{"question_id": 152, "category": "humanities", "turns": ["How do the stages of life shape our understanding of time and mortality?", "Write an allegorical poem that illustrates the above."]} +{"question_id": 153, "category": "humanities", "turns": ["Discuss antitrust laws and their impact on market competition. Compare the antitrust laws in US and China along with some case studies.", "Pick one case study and explain it in detail."]} +{"question_id": 154, "category": "humanities", "turns": ["Create a lesson plan that integrates drama, mime or theater techniques into a history class. Duration: 3 class periods (each lasts for 45 minutes) for 3 days\nTopic: Opium Wars between China and Britain\nGrade level: 9-10", "Provide more details for Day 1 and include three homework questions."]} +{"question_id": 155, "category": "humanities", "turns": ["Share ideas for adapting art masterpieces into interactive experiences for children. List 5 specific artworks and associated ideas.", "Write a concrete plan for your second example. 
Include budget estimates."]} +{"question_id": 156, "category": "humanities", "turns": ["Explain what's base rate fallacy and list five specific examples of how politicians use it for campaigns.", "Provide a detailed plan for an election campaign using the first example."]} +{"question_id": 157, "category": "humanities", "turns": ["Describe five key principles in evaluating an argument in analytical writing.", "With the listed principles, write a response in which you discuss what specific evidence is needed to evaluate the argument and explain how the evidence would weaken or strengthen the argument.\n\n===\n\nThe following is a memorandum from the advertising head of Zorblatt Animal Outlets, a chain operating thirty animal outlets globally.\n\n\"Half a decade ago, our rival Aquatic Pavilion started publicizing in Rare Pets Digest periodical. Their overall sales have been consistently growing at a rate of 3-to-5 percent each year since then. In particular, the Aquatic Pavilion outlet in Harbor Town experienced even more significant growth, securing the title of the most frequented animal store in the United States the previous year. In contrast, our two Zorblatt outlets in Harbor Town have recorded a consistent drop in sales during the same duration. It is evident that we must promptly start featuring our own advertisements in Rare Pets Digest and other popular animal publications. If we take this step, we can confidently anticipate a reversal in this recent trend of decreasing sales and return to profitability.\""]} +{"question_id": 158, "category": "humanities", "turns": ["Which methods did Socrates employ to challenge the prevailing thoughts of his time?", "Let's bring Socrates to modern world. Generate a conversation between Socrates and Bill Gates to debate on generative AI for education."]} +{"question_id": 159, "category": "humanities", "turns": ["What are some business etiquette norms when doing business in Japan?", "Create a video script for training new employees of a car wash business in Japan. Highlight the above etiquette norms."]} +{"question_id": 160, "category": "humanities", "turns": ["Suggest five award-winning documentary films with brief background descriptions for aspiring filmmakers to study.", "With the spirit in the first film, craft a succinct and persuasive pitch for a film about overcoming adversity."]} diff --git a/fastchat/llm_judge/data/mt_bench/reference_answer/gpt-4.jsonl b/fastchat/llm_judge/data/mt_bench/reference_answer/gpt-4.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..8b66387a3129cbb117178fe515491df553dd52df --- /dev/null +++ b/fastchat/llm_judge/data/mt_bench/reference_answer/gpt-4.jsonl @@ -0,0 +1,30 @@ +{"question_id": 101, "answer_id": "TFomieEmmAgdeCkvmuvwbc", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["If you have just overtaken the second person, your current position is now second place. The person you just overtook is now in third place.", "If you have just overtaken the last person, it means you were previously the second to last person in the race. After overtaking the last person, your position remains the same, which is second to last. The person you just overtook is now in the last place."]}], "tstamp": 1686286924.844282} +{"question_id": 102, "answer_id": "hLH8WozvaB88bb5vV224H4", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["The White House is located at 1600 Pennsylvania Avenue NW in Washington, D.C. 
It is the official residence and workplace of the President of the United States.", "No, the original question does not contain any clues to definitively determine the location of the White House. It only describes a red house, a greenhouse, and a heated pink place, which are unrelated to the White House's location."]}], "tstamp": 1686286937.7164738} +{"question_id": 103, "answer_id": "6giUWH2cf9aYzbtfhDebPx", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["There could be several reasons for Thomas to visit the hospital daily despite being healthy. Some possible reasons include:\n\n1. Work: Thomas might be working at the hospital as a healthcare professional, administrative staff, or in another capacity.\n\n2. Caregiver: Thomas could be a caregiver for a family member or friend who requires daily medical attention or support.\n\n3. Volunteer: Thomas might be volunteering at the hospital to help patients, assist staff, or participate in community outreach programs.\n\n4. Medical research: Thomas could be participating in a clinical trial or medical research study that requires daily visits to the hospital.\n\n5. Therapy or rehabilitation: Thomas might be attending daily therapy or rehabilitation sessions for a previous injury or condition, even if he is currently in good health.\n\n6. Preventive care: Thomas could be receiving daily preventive care treatments, such as dialysis or intravenous medication, to maintain his health.\n\n7. Medical training: Thomas might be a medical student or resident who needs to attend daily classes, rounds, or training sessions at the hospital.\n\nRemember, these are just possible reasons, and the actual reason for Thomas's daily hospital visits could be different or a combination of these factors.", "The question about Thomas visiting the hospital daily despite being healthy is interesting for several reasons:\n\n1. Contradictory nature: The question presents a seemingly contradictory situation, as people generally associate hospital visits with illness or injury. This piques curiosity and encourages deeper thinking to understand the reasons behind the scenario.\n\n2. Diverse possibilities: The question opens up a wide range of possible explanations, which can lead to interesting discussions and insights into various aspects of healthcare, work, and personal life.\n\n3. Real-life relevance: The question is relatable and relevant to real-life situations, as many people might know someone who frequently visits a hospital for various reasons. This makes the question engaging and thought-provoking.\n\n4. Encourages empathy: By exploring the reasons behind Thomas's daily hospital visits, the question encourages empathy and understanding of the challenges and responsibilities that people might face in their lives.\n\n5. Highlights different roles: The question brings attention to the various roles people play in a hospital setting, such as healthcare professionals, caregivers, volunteers, and patients. 
This can lead to a greater appreciation of the diverse contributions made by individuals in the healthcare system.\n\nOverall, the question is interesting because it challenges assumptions, encourages critical thinking, and promotes empathy and understanding of different perspectives."]}], "tstamp": 1686286997.2886286} +{"question_id": 104, "answer_id": "R3CL3as5EQPJ25sC4E26fL", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["David has only one brother.", "If each of David's sisters has two brothers, then David would have two brothers."]}], "tstamp": 1686287001.4130871} +{"question_id": 105, "answer_id": "Eybkkrhq4wRjEx83CEzZx8", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["The name of the secretary is Cheryl.\n\nExplanation:\n1. The car in the first space is red (CEO).\n2. A blue car is parked between the red car and the green car.\n3. The car in the last space is purple (David).\n4. The secretary drives a yellow car.\n5. Alice's car is parked next to David's.\n6. Enid drives a green car.\n7. Bert's car is parked between Cheryl's and Enid's.\n\nFrom the information given, we can deduce the following:\n- David's car is purple and parked in the last space (treasurer).\n- Since Alice's car is parked next to David's, her car is blue and she is the vice president.\n- Enid drives a green car, so she is the president.\n- Bert's car is parked between Cheryl's and Enid's, so his car is red and he is the CEO.\n- Cheryl is the only executive left, so she drives the yellow car and is the secretary.", "The car colors in order from last to first are: purple, yellow, green, blue, and red."]}], "tstamp": 1686287032.6753678} +{"question_id": 106, "answer_id": "m7QVKQZLpNSHu5gyb6UrUe", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["true.", "If the third statement is true, then bananas cost more than apples and bananas cost more than oranges. This means that oranges could still cost more than apples, making the first statement true. However, it's also possible that oranges and apples have the same cost, which would make the first statement false. Therefore, based on the third statement alone, the first statement is uncertain."]}], "tstamp": 1686287041.9744587} +{"question_id": 107, "answer_id": "SgukRRMp6z5UjJubB3V8FQ", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["A is the grandfather of C.", "Let's break down the relationships step by step:\n\n1. A is the father of B (A is B's father)\n2. B is the father of C (B is C's father)\n3. C is the son of D (C is D's son)\n4. D is the father of E (D is E's father)\n5. E is the son of X (E is X's son)\n6. X is the father of Y (X is Y's father)\n7. Y is the father of Z (Y is Z's father)\n\nNow, let's analyze the relationships:\n\n1. A is B's father, so A is one generation above B.\n2. B is C's father, so B is one generation above C.\n3. C is D's son, so C and D are in the same generation.\n4. D is E's father, so D is one generation above E.\n5. E is X's son, so E and X are in the same generation.\n6. X is Y's father, so X is one generation above Y.\n7. Y is Z's father, so Y is one generation above Z.\n\nNow, let's count the generations between A and Z:\n\n1. A is one generation above B.\n2. B is one generation above C.\n3. C is in the same generation as D.\n4. D is one generation above E.\n5. E is in the same generation as X.\n6. X is one generation above Y.\n7. 
Y is one generation above Z.\n\nSo, A is four generations above Z.\n\nIn terms of familial relationship, A is the great-great-grandfather of Z."]}], "tstamp": 1686287092.1801407} +{"question_id": 108, "answer_id": "PfLkefsGQrCz3DumAL3q5N", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["Car does not belong with the others, as tyre, steering wheel, and engine are all parts of a car, while car is the whole vehicle.", "Sure, I can replace \"car\" with \"brake pedal.\" Now the list includes tyre, steering wheel, engine, and brake pedal, which are all parts of a car."]}], "tstamp": 1686287101.624609} +{"question_id": 109, "answer_id": "MAAyYhUQHucgwafVCwBS3m", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["The shadow of the pole fell exactly to Suresh's right. Since the sun rises in the east, it means that Suresh was facing north. Therefore, the shadow of the pole was pointing towards the west. Here's the reasoning:\n\n1. Sun rises in the east.\n2. Shadows are cast in the opposite direction of the sunlight.\n3. Suresh was facing the pole, and the shadow was to his right.\n4. Since the sun was in the east, Suresh must have been facing north for the shadow to be on his right side.\n5. As a result, the shadow was pointing towards the west.", "Suresh was facing north. Here's how we solve this:\n\n1. Sun rises in the east.\n2. Shadows are cast in the opposite direction of the sunlight.\n3. Suresh was facing the pole, and the shadow was to his right.\n4. Since the sun was in the east, Suresh must have been facing north for the shadow to be on his right side.\n\nSo, Suresh was facing the north direction."]}], "tstamp": 1686287131.4559815} +{"question_id": 110, "answer_id": "KrP6js74X9TLbcaiRHUMmw", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["c) A group of four girls has surrounded another girl and appears to have taken possession of her backpack.", "To determine if the girls are telling the truth or covering up bullying, the aides should look for the following evidence:\n\n1. The victim's body language and facial expressions: If the girl appears upset, scared, or uncomfortable, it may indicate that she is being bullied.\n\n2. The tone of the conversation: If the girls are speaking in a mocking, sarcastic, or aggressive tone, it could be a sign of bullying.\n\n3. The history of the individuals involved: If any of the girls have a history of bullying or if the victim has been targeted before, it may be more likely that bullying is occurring.\n\n4. The reactions of other students nearby: If other students appear to be uncomfortable or are trying to avoid the situation, it could be an indication that something is wrong.\n\n5. The nature of the \"game\": If the game involves teasing, humiliation, or exclusion, it may be a cover-up for bullying.\n\n6. The willingness of the girls to return the backpack: If they are reluctant to give the backpack back or continue to taunt the victim, it could be a sign of bullying.\n\n7. 
The response of the group when confronted: If the girls become defensive, dismissive, or aggressive when questioned, it may indicate that they are trying to hide their actions."]}], "tstamp": 1686287166.1482828} +{"question_id": 111, "answer_id": "NJPTPe2PZ7547UMekB9kBM", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To find the area of the triangle, we can use the formula for the area of a triangle with coordinates A(x1, y1), B(x2, y2), and C(x3, y3):\n\nArea = (1/2) * |(x1(y2 - y3) + x2(y3 - y1) + x3(y1 - y2))|\n\nIn this case, the coordinates are A(0, 0), B(-1, 1), and C(3, 3). Plugging these values into the formula, we get:\n\nArea = (1/2) * |(0(1 - 3) + (-1)(3 - 0) + 3(0 - 1))|\n\nArea = (1/2) * |(-0 + 3 - 3)|\n\nArea = (1/2) * |0|\n\nArea = 0\n\nThe area of the triangle is 0. This means that the three points are collinear, and there is no triangle formed by these points.", "Since the three points are collinear and do not form a triangle, there is no circumscribed circle for this set of points. A circumscribed circle can only be formed around a triangle with non-collinear points."]}], "tstamp": 1686287200.2472498} +{"question_id": 112, "answer_id": "G5y5Wt7tzmiDnn3Ncyn2ZB", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["In the first year, the startup invested $8000. In the second year, they invested half of that amount, which is $8000 / 2 = $4000. The total amount invested in software development over the two years is $8000 + $4000 = $12000.", "In the third year, the startup will invest half of the second year's amount, which is $4000 / 2 = $2000."]}], "tstamp": 1686287230.6318364} +{"question_id": 113, "answer_id": "Cr2dJDVXKARY4a3CdeYV85", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To find the probability that a student likes neither blue nor green, we can use the formula for the probability of the union of two events:\n\nP(A \u222a B) = P(A) + P(B) - P(A \u2229 B)\n\nIn this case, A represents the event of liking blue, B represents the event of liking green, and A \u2229 B represents the event of liking both colors.\n\nP(A) = 58% (probability of liking blue)\nP(B) = 45% (probability of liking green)\nP(A \u2229 B) = 22% (probability of liking both colors)\n\nNow, we can plug these values into the formula:\n\nP(A \u222a B) = 58% + 45% - 22% = 81%\n\nThis gives us the probability that a student likes either blue or green. To find the probability that a student likes neither color, we can subtract this value from 100%:\n\nP(neither blue nor green) = 100% - 81% = 19%\n\nSo, the probability that a randomly picked student would like neither blue nor green is 19%.", "If we select a student who likes green, we want to find the probability that they dislike blue. We know that 22% of students like both colors (blue and green). Since 45% of students like green, we can calculate the probability that a student who likes green dislikes blue as follows:\n\nP(dislikes blue | likes green) = (P(likes green) - P(likes both)) / P(likes green)\n\nP(dislikes blue | likes green) = (45% - 22%) / 45% = 23% / 45% \u2248 0.5111\n\nSo, the probability that a student who likes green would dislike blue is approximately 51.11%."]}], "tstamp": 1686287283.037716} +{"question_id": 114, "answer_id": "8o3RVi7vGbhRGuN38m7eXb", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["When rolling two dice, there are a total of 6 sides on each die, which means there are 6 x 6 = 36 possible outcomes. 
To find the probability of rolling a total number that is at least 3, we need to find the number of outcomes that result in a sum of 3 or more and divide that by the total number of outcomes.\n\nThe only sums that are less than 3 are 1 (rolling a 1 on both dice) and 2 (rolling a 1 on one die and a 2 on the other, or vice versa). There are 2 outcomes that result in a sum of 2: (1,1) and (1,2), (2,1).\n\nSo, there are 36 total outcomes and 2 outcomes that result in a sum less than 3. Therefore, there are 36 - 2 = 34 outcomes that result in a sum of 3 or more.\n\nThe probability of rolling a total number that is at least 3 is the number of favorable outcomes (34) divided by the total number of outcomes (36):\n\nProbability = 34/36 = 17/18 \u2248 0.944\n\nSo, the probability of rolling a total number that is at least 3 is approximately 0.944 or 94.4%.", "We already know that the probability of rolling a total number that is at least 3 is 34/36. Now, we need to find the probability of rolling an even number and then combine these probabilities using the principle of inclusion and exclusion.\n\nAn even number can be obtained by rolling (1,1), (1,3), (1,5), (2,2), (2,4), (2,6), (3,1), (3,3), (3,5), (4,2), (4,4), (4,6), (5,1), (5,3), (5,5), (6,2), (6,4), and (6,6). There are 18 outcomes that result in an even sum.\n\nThe probability of rolling an even number is the number of favorable outcomes (18) divided by the total number of outcomes (36):\n\nProbability (even) = 18/36 = 1/2 = 0.5\n\nNow, we need to find the probability of rolling a number that is both even and at least 3. The outcomes that meet both criteria are (2,2), (2,4), (2,6), (4,2), (4,4), (4,6), (6,2), (6,4), and (6,6). There are 9 such outcomes.\n\nThe probability of rolling a number that is both even and at least 3 is the number of favorable outcomes (9) divided by the total number of outcomes (36):\n\nProbability (even and at least 3) = 9/36 = 1/4 = 0.25\n\nNow, we can use the principle of inclusion and exclusion to find the probability of rolling a number that is either even or at least 3:\n\nProbability (even or at least 3) = Probability (even) + Probability (at least 3) - Probability (even and at least 3)\n\nProbability (even or at least 3) = 0.5 + 0.944 - 0.25 = 1.194 - 0.25 = 0.944\n\nSo, the probability of rolling a number that is either even or at least 3 is 0.944 or 94.4%."]}], "tstamp": 1686287374.281228} +{"question_id": 115, "answer_id": "NvS4RHPbnRvJr4jE66ALAg", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["Let's denote the number of people who got on the bus at the terminal as x.\n\nAt the first bus stop, half of the people got down, so there were x/2 people left on the bus. Then, 4 more people got in, making the total number of people on the bus x/2 + 4.\n\nAt the second bus stop, 6 people got down, so there were (x/2 + 4) - 6 people left on the bus. 
Then, 8 more people got in, making the total number of people on the bus (x/2 + 4) - 6 + 8.\n\nWe know that there were 25 people heading to the third stop, so we can set up the following equation:\n\n(x/2 + 4) - 6 + 8 = 25\n\nNow, let's solve for x:\n\nx/2 + 4 - 6 + 8 = 25\nx/2 + 6 = 25\nx/2 = 19\nx = 38\n\nSo, 38 people got on the bus at the terminal.", "To find the total money earned by the bus, we need to consider the number of people who got on the bus at the terminal and at each bus stop.\n\nAt the terminal, 38 people got on the bus.\nAt the first bus stop, 4 people got on the bus.\nAt the second bus stop, 8 people got on the bus.\n\nSo, the total number of people who got on the bus is 38 + 4 + 8 = 50 people.\n\nSince the ticket price is $2 per person, the total money earned by the bus is 50 people * $2 = $100."]}], "tstamp": 1686287429.616392} +{"question_id": 116, "answer_id": "VNzSjZeTByKA4jwfeLi5c6", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["We have two equations:\n\n1) x + y = 4z\n2) xy = 4z^2\n\nFirst, let's solve equation 1 for x:\n\nx = 4z - y\n\nNow, substitute this expression for x in equation 2:\n\n(4z - y)y = 4z^2\n\nExpand the equation:\n\n4zy - y^2 = 4z^2\n\nNow, let's solve this equation for y:\n\ny^2 - 4zy + 4z^2 = 0\n\nThis is a quadratic equation in y. We can use the quadratic formula to solve for y:\n\ny = (4z \u00b1 \u221a((4z)^2 - 4(4z^2))) / 2\n\ny = (4z \u00b1 \u221a(16z^2 - 16z^2)) / 2\n\ny = (4z \u00b1 0) / 2\n\ny = 2z\n\nNow that we have y in terms of z, we can substitute it back into the expression for x:\n\nx = 4z - y\nx = 4z - 2z\nx = 2z\n\nFinally, we can express x - y in terms of z:\n\nx - y = 2z - 2z = 0", "We have the following equations:\n\n1) x + y = 4z\n2) xy = 4z^2\n\nFrom the previous solution, we found that:\n\nx = 2z\ny = 2z\n\nNow, let's express z - x in terms of y. 
Since x = 2z, we can write z - x as:\n\nz - x = z - 2z = -z\n\nNow, we can use the equation y = 2z to express -z in terms of y:\n\n- z = - (y / 2)\n\nSo, z - x = - (y / 2)."]}], "tstamp": 1686287481.5740478} +{"question_id": 117, "answer_id": "WTLAqXaca6gT8T4Mq46vuZ", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To solve the inequality |x + 5| < 10, we can break it down into two separate inequalities:\n\n1) x + 5 < 10\n2) -(x + 5) < 10\n\nNow, let's solve each inequality:\n\n1) x + 5 < 10\nSubtract 5 from both sides:\nx < 5\n\n2) -(x + 5) < 10\nDistribute the negative sign:\n-x - 5 < 10\nAdd 5 to both sides:\n-x < 15\nMultiply both sides by -1 (and remember to flip the inequality sign):\nx > -15\n\nNow, we can combine the two inequalities:\n-15 < x < 5\n\nTo find the number of integers in this range, we can count them:\n-14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4\n\nThere are 19 integers in the solution of the inequality |x + 5| < 10.", "To solve the inequality |x + 10| < 5, we can break it down into two separate inequalities:\n\n1) x + 10 < 5\n2) -(x + 10) < 5\n\nNow, let's solve each inequality:\n\n1) x + 10 < 5\nSubtract 10 from both sides:\nx < -5\n\n2) -(x + 10) < 5\nDistribute the negative sign:\n-x - 10 < 5\nMultiply both sides by -1 (and remember to flip the inequality sign):\nx + 10 > -5\nSubtract 10 from both sides:\nx > -15\n\nNow, we can combine the two inequalities:\n-15 < x < -5\n\nTo find the number of integers in this range, we can count them:\n-14, -13, -12, -11, -10, -9, -8, -7, -6\n\nThere are 9 integers in the solution of the inequality |x + 10| < 5."]}], "tstamp": 1686287605.8658118} +{"question_id": 118, "answer_id": "NG6DyAsfpXKz2ZZ7nfQ6LD", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["Let the number be x. We are given that x % 10 = 4. We need to find the remainder when 2x is divided by 4, i.e., (2x) % 4.\n\nSince x % 10 = 4, we can write x as x = 10n + 4 for some integer n.\n\nNow, let's find 2x:\n2x = 2(10n + 4) = 20n + 8\n\nNow, let's find the remainder when 2x is divided by 4:\n(20n + 8) % 4 = (20n % 4) + (8 % 4) = 0 + 0 = 0\n\nSo, the remainder when twice the number is divided by 4 is 0.", "We already found that 2x = 20n + 8. 
Now, we need to find the remainder when 2x is divided by 5, i.e., (2x) % 5.\n\n(20n + 8) % 5 = (20n % 5) + (8 % 5) = 0 + 3 = 3\n\nSo, the remainder when twice the number is divided by 5 is 3."]}], "tstamp": 1686287640.771955} +{"question_id": 119, "answer_id": "5yMS3djYFTH2QwiNPLW3Yy", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To find the total cost of Benjamin's purchases, we need to calculate the cost of each type of book and then add them together.\n\nFor the sci-fi novels: 5 copies * $20 each = $100\nFor the history books: 3 copies * $30 each = $90\nFor the philosophy books: 2 copies * $45 each = $90\n\nNow, we add the costs together: $100 (sci-fi) + $90 (history) + $90 (philosophy) = $280\n\nThe total cost of Benjamin's purchases was $280.", "To find the total revenue, we first need to calculate the selling price of each type of book with a 25% markup and then multiply it by the number of copies sold.\n\nFor the sci-fi novels: $20 * 1.25 (25% markup) = $25 each\nFor the history books: $30 * 1.25 (25% markup) = $37.50 each\nFor the philosophy books: $45 * 1.25 (25% markup) = $56.25 each\n\nNow, we calculate the revenue for each type of book:\nSci-fi novels: 5 copies * $25 each = $125\nHistory books: 3 copies * $37.50 each = $112.50\nPhilosophy books: 2 copies * $56.25 each = $112.50\n\nFinally, we add the revenues together: $125 (sci-fi) + $112.50 (history) + $112.50 (philosophy) = $350\n\nBenjamin's total revenue if he sold all the books at a 25% markup would be $350."]}], "tstamp": 1686287683.9092937} +{"question_id": 120, "answer_id": "NbJahXyyirJ8KKfaMXY6SC", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To find the value of f(2), substitute x = 2 into the function f(x) = 4x^3 - 9x - 14.\n\nf(2) = 4(2)^3 - 9(2) - 14\nf(2) = 4(8) - 18 - 14\nf(2) = 32 - 18 - 14\nf(2) = 14 - 14\nf(2) = 0\n\nSo, the value of f(2) is 0.", "To find the value of x such that f(x) = 0, we need to solve the equation 4x^3 - 9x - 14 = 0. This is a cubic equation, and finding an exact solution can be quite difficult. However, we can use numerical methods to approximate the solution.\n\nOne such method is the Newton-Raphson method, which iteratively refines an initial guess for the root. The formula for the Newton-Raphson method is:\n\nx_new = x_old - f(x_old) / f'(x_old)\n\nwhere f'(x) is the derivative of f(x) with respect to x. First, let's find the derivative of f(x):\n\nf(x) = 4x^3 - 9x - 14\nf'(x) = 12x^2 - 9\n\nNow, let's apply the Newton-Raphson method with an initial guess of x = 1:\n\n1. x_new = 1 - (4(1)^3 - 9(1) - 14) / (12(1)^2 - 9)\n x_new = 1 - (-9) / 3\n x_new = 1 + 3\n x_new = 4\n\n2. x_new = 4 - (4(4)^3 - 9(4) - 14) / (12(4)^2 - 9)\n x_new = 4 - (256 - 36 - 14) / (192 - 9)\n x_new = 4 - (206) / (183)\n x_new \u2248 2.876\n\n3. x_new \u2248 2.876 - (4(2.876)^3 - 9(2.876) - 14) / (12(2.876)^2 - 9)\n x_new \u2248 2.876 - (99.98) / (99.99)\n x_new \u2248 1.876\n\nWe can continue this process until the change in x_new is very small, indicating that we have found an approximate root. After several iterations, we find that:\n\nx \u2248 1.65\n\nSo, the approximate value of x such that f(x) = 0 is x \u2248 1.65. 
Note that this is an approximation, and there may be other roots as well."]}], "tstamp": 1686287765.623619} +{"question_id": 121, "answer_id": "nrgta2MB8aLEabUpEBuqbD", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["Here's a Python program that reads all the text files under a directory and returns the top-5 words with the most number of occurrences:\n\n```python\nimport os\nimport re\nfrom collections import Counter\n\ndef get_files_in_directory(directory, extension=\".txt\"):\n return [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(extension)]\n\ndef read_file(file_path):\n with open(file_path, \"r\", encoding=\"utf-8\") as file:\n return file.read()\n\ndef count_words(text):\n words = re.findall(r'\\w+', text.lower())\n return Counter(words)\n\ndef main():\n directory = input(\"Enter the directory path: \")\n text_files = get_files_in_directory(directory)\n\n word_counter = Counter()\n for file_path in text_files:\n text = read_file(file_path)\n word_counter += count_words(text)\n\n top_5_words = word_counter.most_common(5)\n\n print(\"\\nTop 5 words with the most number of occurrences:\")\n for word, count in top_5_words:\n print(f\"{word}: {count}\")\n\nif __name__ == \"__main__\":\n main()\n```\n\nTo use the program, simply run it and enter the directory path when prompted. The program will then display the top-5 words with the most number of occurrences in the text files under the specified directory.", "Yes, you can parallelize the program using Python's `concurrent.futures` module. Here's the updated version of the program that processes the text files in parallel:\n\n```python\nimport os\nimport re\nfrom collections import Counter\nfrom concurrent.futures import ProcessPoolExecutor\n\ndef get_files_in_directory(directory, extension=\".txt\"):\n return [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(extension)]\n\ndef read_file(file_path):\n with open(file_path, \"r\", encoding=\"utf-8\") as file:\n return file.read()\n\ndef count_words(text):\n words = re.findall(r'\\w+', text.lower())\n return Counter(words)\n\ndef process_file(file_path):\n text = read_file(file_path)\n return count_words(text)\n\ndef main():\n directory = input(\"Enter the directory path: \")\n text_files = get_files_in_directory(directory)\n\n word_counter = Counter()\n\n with ProcessPoolExecutor() as executor:\n results = executor.map(process_file, text_files)\n\n for result in results:\n word_counter += result\n\n top_5_words = word_counter.most_common(5)\n\n print(\"\\nTop 5 words with the most number of occurrences:\")\n for word, count in top_5_words:\n print(f\"{word}: {count}\")\n\nif __name__ == \"__main__\":\n main()\n```\n\nIn this version of the program, we use a `ProcessPoolExecutor` to process the text files in parallel. 
The `process_file` function reads and counts the words in a single file, and the results are combined in the main function using the `+=` operator on the `Counter` objects."]}], "tstamp": 1686287847.5641448} +{"question_id": 122, "answer_id": "k3DDMNG9SFBxdzBr6g3VRF", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["Here's a C++ program to find the nth Fibonacci number using recursion:\n\n```cpp\n#include \n\n// Function to find the nth Fibonacci number using recursion\nint fibonacci(int n) {\n if (n <= 1) {\n return n;\n } else {\n return fibonacci(n - 1) + fibonacci(n - 2);\n }\n}\n\nint main() {\n int n;\n\n std::cout << \"Enter the value of n: \";\n std::cin >> n;\n\n int result = fibonacci(n);\n std::cout << \"The \" << n << \"th Fibonacci number is: \" << result << std::endl;\n\n return 0;\n}\n```\n\nTo compile and run the program, save it to a file named `fibonacci_recursion.cpp` and then use the following commands:\n\n```sh\ng++ -o fibonacci_recursion fibonacci_recursion.cpp\n./fibonacci_recursion\n```\n\nEnter the value of `n` when prompted, and the program will output the nth Fibonacci number. Note that this recursive approach can be slow for large values of `n` due to the repeated calculations. A more efficient approach would be to use dynamic programming or memoization.", "Here's a C++ program to find the nth number in the sequence where each number is the sum of the three preceding ones, with the first three numbers being 0, -1, and -1:\n\n```cpp\n#include \n#include \n\n// Function to find the nth number in the sequence\nint find_nth_number(int n) {\n std::vector sequence = {0, -1, -1};\n\n for (int i = 3; i < n; ++i) {\n sequence.push_back(sequence[i - 1] + sequence[i - 2] + sequence[i - 3]);\n }\n\n return sequence[n - 1];\n}\n\nint main() {\n int n;\n\n std::cout << \"Enter the value of n: \";\n std::cin >> n;\n\n int result = find_nth_number(n);\n std::cout << \"The \" << n << \"th number in the sequence is: \" << result << std::endl;\n\n return 0;\n}\n```\n\nTo compile and run the program, save it to a file named `sequence.cpp` and then use the following commands:\n\n```sh\ng++ -o sequence sequence.cpp\n./sequence\n```\n\nEnter the value of `n` when prompted, and the program will output the nth number in the sequence. This approach uses a loop and a vector to store the sequence, making it more efficient than a recursive solution."]}], "tstamp": 1686287917.543315} +{"question_id": 123, "answer_id": "VAG9qBBc46KaftN4xTmcQJ", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["\n\n\n \n \n Random Joke Generator\n \n \n\n\n

Random Joke Generator

\n \n

\n\n", "You can use CSS to change the color of the jokes to red by adding a style rule for the paragraph element where the joke is displayed. In the `\n \n\n\n

Random Joke Generator

\n \n

\n\n\n```\n\nNow, when a joke is displayed, its text color will be red."]}], "tstamp": 1686288005.79547} +{"question_id": 124, "answer_id": "9Zpf8G2Fzt6nGJubsPLcRz", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["The provided function appears to be correct. It uses dynamic programming to find the length of the longest common subsequence of two input strings, `str1` and `str2`. The function initializes a 2D array `dp` of size `(m+1) x (n+1)` where `m` and `n` are the lengths of the input strings. It then iterates through the characters of both strings and fills the `dp` array according to the longest common subsequence found so far. Finally, it returns the value at `dp[m][n]`, which represents the length of the longest common subsequence of the two input strings.", "This function is also correct. It is another implementation of the dynamic programming approach to find the length of the longest common subsequence of two input strings, `X` and `Y`. The function initializes a 2D array `dp` of size `(m+1) x (n+1)` where `m` and `n` are the lengths of the input strings. It then iterates through the characters of both strings and fills the `dp` array according to the longest common subsequence found so far. Finally, it returns the value at `dp[m][n]`, which represents the length of the longest common subsequence of the two input strings.\n\nHowever, there is a small issue in this function: the first row and the first column of the `dp` array are not initialized to 0. To fix this, you can add the following lines of code before the nested loops:\n\n```python\nfor i in range(m+1):\n dp[i][0] = 0\n\nfor j in range(n+1):\n dp[0][j] = 0\n```\n\nWith these changes, the function should work correctly."]}], "tstamp": 1686288048.3319974} +{"question_id": 125, "answer_id": "P8BsTS9cZtrPYD59ky4cPw", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To find the highest common ancestor (HCA) of two nodes in a binary tree, you can use a recursive approach. Here's a Python function to achieve this:\n\n```python\nclass TreeNode:\n def __init__(self, value=0, left=None, right=None):\n self.value = value\n self.left = left\n self.right = right\n\ndef find_hca(root, node1, node2):\n if root is None:\n return None\n\n if root.value == node1.value or root.value == node2.value:\n return root\n\n left_hca = find_hca(root.left, node1, node2)\n right_hca = find_hca(root.right, node1, node2)\n\n if left_hca and right_hca:\n return root\n\n return left_hca if left_hca is not None else right_hca\n\n# Example usage:\n# Construct a binary tree\n# 3\n# / \\\n# 5 1\n# / \\ / \\\n# 6 2 0 8\n# / \\\n# 7 4\nroot = TreeNode(3)\nroot.left = TreeNode(5)\nroot.right = TreeNode(1)\nroot.left.left = TreeNode(6)\nroot.left.right = TreeNode(2)\nroot.right.left = TreeNode(0)\nroot.right.right = TreeNode(8)\nroot.left.right.left = TreeNode(7)\nroot.left.right.right = TreeNode(4)\n\nnode1 = root.left.left # Node with value 6\nnode2 = root.left.right.right # Node with value 4\n\nhca = find_hca(root, node1, node2)\nprint(\"Highest Common Ancestor value:\", hca.value) # Output: 5\n```\n\nThis function uses a depth-first search to traverse the binary tree. It checks if the current node is one of the target nodes, and if so, returns the current node. If not, it continues the search in the left and right subtrees. If both subtrees return a non-null value, it means the current node is the HCA. 
Otherwise, the function returns the non-null subtree result.", "If it's not a binary tree but a general tree with multiple children per node, you can modify the function to handle this case. Here's a Python function to find the highest common ancestor (HCA) of two nodes in a general tree:\n\n```python\nclass TreeNode:\n def __init__(self, value=0, children=None):\n self.value = value\n self.children = children if children is not None else []\n\ndef find_hca(root, node1, node2):\n if root is None:\n return None\n\n if root.value == node1.value or root.value == node2.value:\n return root\n\n hca_list = []\n for child in root.children:\n hca_child = find_hca(child, node1, node2)\n if hca_child:\n hca_list.append(hca_child)\n\n if len(hca_list) >= 2:\n return root\n elif len(hca_list) == 1:\n return hca_list[0]\n else:\n return None\n\n# Example usage:\n# Construct a general tree\n# 1\n# / | \\\n# 2 3 4\n# /| | |\\\n# 5 6 7 8 9\n# / \\\n# 10 11\nroot = TreeNode(1, [\n TreeNode(2, [\n TreeNode(5),\n TreeNode(6)\n ]),\n TreeNode(3, [\n TreeNode(7)\n ]),\n TreeNode(4, [\n TreeNode(8, [\n TreeNode(10),\n TreeNode(11)\n ]),\n TreeNode(9)\n ])\n])\n\nnode1 = root.children[0].children[1] # Node with value 6\nnode2 = root.children[2].children[0].children[1] # Node with value 11\n\nhca = find_hca(root, node1, node2)\nprint(\"Highest Common Ancestor value:\", hca.value) # Output: 1\n```\n\nThis function is similar to the binary tree version, but instead of checking only left and right children, it iterates through all children of the current node. If two or more children return a non-null value, it means the current node is the HCA. Otherwise, the function returns the non-null subtree result."]}], "tstamp": 1686288164.937362} +{"question_id": 126, "answer_id": "EURwBVADp4nAoxMuetcicm", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["It's not possible to find the median of two sorted arrays of different sizes with O(1) space complexity and O(n) time complexity. However, you can achieve O(log(min(n, m))) time complexity using a binary search algorithm. Here's a Python implementation of such a function:\n\n```python\ndef findMedianSortedArrays(nums1, nums2):\n if len(nums1) > len(nums2):\n nums1, nums2 = nums2, nums1\n\n x, y = len(nums1), len(nums2)\n low, high = 0, x\n\n while low <= high:\n partition_x = (low + high) // 2\n partition_y = (x + y + 1) // 2 - partition_x\n\n max_left_x = float('-inf') if partition_x == 0 else nums1[partition_x - 1]\n min_right_x = float('inf') if partition_x == x else nums1[partition_x]\n\n max_left_y = float('-inf') if partition_y == 0 else nums2[partition_y - 1]\n min_right_y = float('inf') if partition_y == y else nums2[partition_y]\n\n if max_left_x <= min_right_y and max_left_y <= min_right_x:\n if (x + y) % 2 == 0:\n return (max(max_left_x, max_left_y) + min(min_right_x, min_right_y)) / 2\n else:\n return max(max_left_x, max_left_y)\n elif max_left_x > min_right_y:\n high = partition_x - 1\n else:\n low = partition_x + 1\n\n raise ValueError(\"Input arrays are not sorted.\")\n```\n\nThis function takes two sorted arrays `nums1` and `nums2` as input and returns the median value. The time complexity of this function is O(log(min(n, m))), where n and m are the lengths of the input arrays. The space complexity is O(1) as it only uses a constant amount of extra space.", "The implementation provided above with a time complexity of O(log(min(n, m))) is already an efficient solution for finding the median of two sorted arrays. 
This is because the problem of finding the median of two sorted arrays can be reduced to finding the k-th element in the merged array, where k is the middle index. The binary search algorithm used in the implementation takes advantage of the sorted nature of the input arrays and efficiently narrows down the search space.\n\nIn general, it is not possible to achieve a better time complexity than O(log(min(n, m))) for this problem, as any algorithm that solves this problem must examine at least a logarithmic number of elements in the smaller array to determine the median. This is because the median depends on the relative ordering of the elements in both arrays, and a logarithmic number of comparisons is required to establish this ordering.\n\nTherefore, the provided implementation with a time complexity of O(log(min(n, m))) is already an efficient and optimal solution for finding the median of two sorted arrays."]}], "tstamp": 1686288249.211752} +{"question_id": 127, "answer_id": "eYa4BTZvWrYTHhTdzX7E6g", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["Here's a Python function that finds the majority element in a given integer array using the Boyer-Moore Voting Algorithm:\n\n```python\ndef majority_element(arr):\n candidate = None\n count = 0\n\n # Boyer-Moore Voting Algorithm\n for num in arr:\n if count == 0:\n candidate = num\n count += (1 if num == candidate else -1)\n\n # Verify if the candidate is indeed the majority element\n if arr.count(candidate) > len(arr) // 2:\n return candidate\n else:\n return None\n\n# Example usage:\narr = [2, 2, 1, 1, 1, 2, 2]\nprint(majority_element(arr)) # Output: 2\n```\n\nThis function first finds a candidate for the majority element using the Boyer-Moore Voting Algorithm. Then, it verifies if the candidate is indeed the majority element by counting its occurrences in the array. If the candidate occurs more than half of the array's length, it is returned as the majority element; otherwise, the function returns None.", "To find the top-2 most occurring elements in a given integer array, you can modify the Boyer-Moore Voting Algorithm. Here's a Python function that does this:\n\n```python\ndef top_two_elements(arr):\n candidate1, candidate2 = None, None\n count1, count2 = 0, 0\n\n # Modified Boyer-Moore Voting Algorithm\n for num in arr:\n if num == candidate1:\n count1 += 1\n elif num == candidate2:\n count2 += 1\n elif count1 == 0:\n candidate1 = num\n count1 = 1\n elif count2 == 0:\n candidate2 = num\n count2 = 1\n else:\n count1 -= 1\n count2 -= 1\n\n # Verify if the candidates are indeed the top-2 most occurring elements\n count1 = arr.count(candidate1)\n count2 = arr.count(candidate2)\n\n if count1 > count2:\n return candidate1, candidate2\n else:\n return candidate2, candidate1\n\n# Example usage:\narr = [1, 1, 2, 2, 3, 3, 1, 2, 3, 4, 5]\nprint(top_two_elements(arr)) # Output: (1, 2)\n```\n\nThis function first finds two candidates for the top-2 most occurring elements using a modified version of the Boyer-Moore Voting Algorithm. Then, it verifies the counts of the candidates in the array and returns them in descending order of their counts. 
Note that this function assumes that there are at least two distinct elements in the array."]}], "tstamp": 1686288322.0523534} +{"question_id": 128, "answer_id": "BMZdfLQQke8d9j78rSmLsf", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To find B_n, we can use the following recursive formula:\n\nB_n = sum(B_i * B_(n-1-i)) for i = 0 to (n-2)/2, where n is even.\n\nHere's a Python function to implement this:\n\n```python\ndef full_binary_trees(n):\n if n % 2 == 0 or n < 1:\n return 0\n if n == 1:\n return 1\n\n B = [0] * (n + 1)\n B[1] = 1\n\n for i in range(3, n + 1, 2):\n for j in range(0, i - 1, 2):\n B[i] += B[j] * B[i - 1 - j]\n\n return B[n]\n\n# Example usage:\nn = 5\nprint(f\"B_{n} =\", full_binary_trees(n))\n```\n\nThis function first checks if n is even or less than 1, in which case it returns 0 since there are no full binary trees with an even number of vertices or with a negative number of vertices. If n is 1, it returns 1 since there is only one full binary tree with a single vertex (the root). Then, it initializes an array B of size n+1 to store the number of full binary trees for each value of n. It iterates through the odd numbers from 3 to n, and for each i, it calculates B_i using the recursive formula mentioned above. Finally, it returns B_n.", "If the problem changes to a ternary tree, we can modify the recursive formula as follows:\n\nT_n = sum(T_i * T_(j) * T_(n-1-i-j)) for i = 0 to (n-2)/2, j = 0 to (n-2)/2, where n is a multiple of 3.\n\nHere's a Python function to implement this for ternary trees:\n\n```python\ndef full_ternary_trees(n):\n if n % 3 != 0 or n < 1:\n return 0\n if n == 1:\n return 1\n\n T = [0] * (n + 1)\n T[1] = 1\n\n for i in range(3, n + 1, 3):\n for j in range(0, i - 1, 3):\n for k in range(0, i - 1 - j, 3):\n T[i] += T[j] * T[k] * T[i - 1 - j - k]\n\n return T[n]\n\n# Example usage:\nn = 9\nprint(f\"T_{n} =\", full_ternary_trees(n))\n```\n\nThis function first checks if n is not a multiple of 3 or less than 1, in which case it returns 0 since there are no full ternary trees with a number of vertices that is not a multiple of 3 or with a negative number of vertices. If n is 1, it returns 1 since there is only one full ternary tree with a single vertex (the root). Then, it initializes an array T of size n+1 to store the number of full ternary trees for each value of n. It iterates through the multiples of 3 from 3 to n, and for each i, it calculates T_i using the modified recursive formula mentioned above. Finally, it returns T_n."]}], "tstamp": 1686288533.7038336} +{"question_id": 129, "answer_id": "62ofoso88i9zStspxDdUER", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To find the kth smallest element in the union of two sorted lists with linear complexity, you can use a two-pointer approach. Here's a Python function that implements this approach:\n\n```python\ndef find_kth_smallest(arr1, arr2, k):\n m, n = len(arr1), len(arr2)\n i, j = 0, 0\n\n while i < m and j < n:\n if arr1[i] < arr2[j]:\n k -= 1\n if k == 0:\n return arr1[i]\n i += 1\n else:\n k -= 1\n if k == 0:\n return arr2[j]\n j += 1\n\n while i < m:\n k -= 1\n if k == 0:\n return arr1[i]\n i += 1\n\n while j < n:\n k -= 1\n if k == 0:\n return arr2[j]\n j += 1\n\n return None\n```\n\nThis function takes two sorted lists `arr1` and `arr2`, and an integer `k` as input. It initializes two pointers `i` and `j` to traverse the two lists. The function iterates through both lists, comparing the elements at the current positions of the pointers. 
If the element in `arr1` is smaller, it increments the pointer `i`. If the element in `arr2` is smaller, it increments the pointer `j`. In each iteration, the function also decrements the value of `k`. When `k` reaches 0, the function returns the current element.\n\nIf one of the lists is exhausted, the function continues iterating through the remaining list until `k` reaches 0.\n\nThe time complexity of this function is O(m + n), which is linear in the size of the input lists.", "Yes, there exists an algorithm with better time complexity. We can use a binary search approach to find the kth smallest element in the union of the two sorted lists with logarithmic complexity. Here's a Python function that implements this approach:\n\n```python\ndef find_kth_smallest(arr1, arr2, k):\n def kth(A, B, k):\n if len(A) == 0:\n return B[k - 1]\n if len(B) == 0:\n return A[k - 1]\n if k == 1:\n return min(A[0], B[0])\n\n mid_A, mid_B = len(A) // 2, len(B) // 2\n if A[mid_A] <= B[mid_B]:\n if k > mid_A + mid_B + 1:\n return kth(A[mid_A + 1:], B, k - mid_A - 1)\n else:\n return kth(A, B[:mid_B], k)\n else:\n if k > mid_A + mid_B + 1:\n return kth(A, B[mid_B + 1:], k - mid_B - 1)\n else:\n return kth(A[:mid_A], B, k)\n\n return kth(arr1, arr2, k)\n```\n\nThis function takes two sorted lists `arr1` and `arr2`, and an integer `k` as input. It defines a helper function `kth` that takes two lists `A` and `B`, and an integer `k`. The helper function performs a binary search on both lists to find the kth smallest element.\n\nThe time complexity of this function is O(log(m) + log(n)), which is logarithmic in the size of the input lists. This is better than the linear complexity of the previous algorithm."]}], "tstamp": 1686288619.2822871} +{"question_id": 130, "answer_id": "B39HvvzTyw5gGQKqNFH6Ed", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["You can implement a program to find the common elements in two arrays without using any extra data structures by using nested loops. Here's an example in Python:\n\n```python\ndef find_common_elements(arr1, arr2):\n # Iterate through each element in the first array\n for i in range(len(arr1)):\n # Iterate through each element in the second array\n for j in range(len(arr2)):\n # If the elements match, print the common element\n if arr1[i] == arr2[j]:\n print(arr1[i], end=\" \")\n\n# Example usage\narr1 = [1, 2, 3, 4, 5]\narr2 = [3, 4, 5, 6, 7]\nfind_common_elements(arr1, arr2)\n```\n\nThis program will output:\n\n```\n3 4 5\n```\n\nKeep in mind that this approach has a time complexity of O(n^2) due to the nested loops. If you were allowed to use extra data structures, you could improve the time complexity by using a set or a hash table.", "Now that we can use extra data structures, we can use a set to store the elements of one array and then check if the elements of the other array are present in the set. This will give us a better time complexity of O(n). 
Here's an example in Python:\n\n```python\ndef find_common_elements(arr1, arr2):\n # Create a set with the elements of the first array\n elements_set = set(arr1)\n\n # Iterate through each element in the second array\n for element in arr2:\n # If the element is present in the set, print the common element\n if element in elements_set:\n print(element, end=\" \")\n\n# Example usage\narr1 = [1, 2, 3, 4, 5]\narr2 = [3, 4, 5, 6, 7]\nfind_common_elements(arr1, arr2)\n```\n\nThis program will output:\n\n```\n3 4 5\n```\n\nUsing a set allows us to achieve a time complexity of O(n) because set operations (add, remove, and search) have an average time complexity of O(1)."]}], "tstamp": 1686288677.5436106} diff --git a/fastchat/llm_judge/data/vicuna_bench/question.jsonl b/fastchat/llm_judge/data/vicuna_bench/question.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..223642384936578e66cf127936251dda6a51d595 --- /dev/null +++ b/fastchat/llm_judge/data/vicuna_bench/question.jsonl @@ -0,0 +1,80 @@ +{"question_id": 1, "category": "generic", "turns": ["How can I improve my time management skills?"]} +{"question_id": 2, "category": "generic", "turns": ["What are the most effective ways to deal with stress?"]} +{"question_id": 3, "category": "generic", "turns": ["What are the main differences between Python and JavaScript programming languages?"]} +{"question_id": 4, "category": "generic", "turns": ["How can I increase my productivity while working from home?"]} +{"question_id": 5, "category": "generic", "turns": ["Can you explain the basics of quantum computing?"]} +{"question_id": 6, "category": "generic", "turns": ["What are the differences between plant-based and animal-based protein sources?"]} +{"question_id": 7, "category": "generic", "turns": ["How can I develop my critical thinking skills?"]} +{"question_id": 8, "category": "generic", "turns": ["What are the major challenges faced by the education sector today?"]} +{"question_id": 9, "category": "generic", "turns": ["What are the primary factors that influence consumer behavior?"]} +{"question_id": 10, "category": "generic", "turns": ["What are the most effective strategies for conflict resolution in the workplace?"]} +{"question_id": 11, "category": "knowledge", "turns": ["What are some potential implications of using a single-use plastic bottle versus a reusable bottle on both the environment and human health?"]} +{"question_id": 12, "category": "knowledge", "turns": ["What factors would you consider when designing an inclusive and accessible public transportation system?"]} +{"question_id": 13, "category": "knowledge", "turns": ["How can governments utilize fiscal and monetary policies to combat economic recessions?"]} +{"question_id": 14, "category": "knowledge", "turns": ["How do language and cultural barriers affect the way people communicate and form relationships in multicultural societies?"]} +{"question_id": 15, "category": "knowledge", "turns": ["Describe a scenario where artificial intelligence could be used to improve the quality and efficiency of healthcare delivery."]} +{"question_id": 16, "category": "knowledge", "turns": ["Explain the process of gene editing using CRISPR-Cas9 technology, and discuss its potential applications and ethical implications."]} +{"question_id": 17, "category": "knowledge", "turns": ["How do vaccinations work to protect individuals and communities from infectious diseases, and what is herd immunity?"]} +{"question_id": 18, "category": "knowledge", "turns": ["How do social media 
platforms influence the way people consume and share news, and what are the potential implications for the spread of misinformation?"]} +{"question_id": 19, "category": "knowledge", "turns": ["How do cultural, social, and economic factors influence people's food choices, and how can this knowledge be used to promote healthier diets?"]} +{"question_id": 20, "category": "knowledge", "turns": ["Explain the process of natural selection and how it contributes to the evolution and adaptation of species."]} +{"question_id": 21, "category": "roleplay", "turns": ["How would you introduce yourself as a medieval knight at a royal banquet?"]} +{"question_id": 22, "category": "roleplay", "turns": ["As a pirate captain, what would you say to your crew to motivate them to search for hidden treasure?"]} +{"question_id": 23, "category": "roleplay", "turns": ["If you were a Shakespearean character, how would you declare your love for someone in a soliloquy?"]} +{"question_id": 24, "category": "roleplay", "turns": ["As a superhero, how would you explain your origin story to a curious child?"]} +{"question_id": 25, "category": "roleplay", "turns": ["Imagine you are a time traveler from the year 3000. What technological advancements would you tell people about?"]} +{"question_id": 26, "category": "roleplay", "turns": ["As a sports commentator, describe the winning play in the final seconds of a championship game."]} +{"question_id": 27, "category": "roleplay", "turns": ["Pretend to be a world-famous chef. How would you describe your signature dish to a panel of judges?"]} +{"question_id": 28, "category": "roleplay", "turns": ["You are a mountain climber reaching the summit of Mount Everest. Describe your emotions and the view from the top."]} +{"question_id": 29, "category": "roleplay", "turns": ["As a space colonist on Mars, describe your daily life and the challenges you face living on another planet."]} +{"question_id": 30, "category": "roleplay", "turns": ["Pretend to be a character in a post-apocalyptic world. 
Describe how you survive and the allies you encounter."]} +{"question_id": 31, "category": "common-sense", "turns": ["How can you determine if a restaurant is popular among locals or mainly attracts tourists, and why might this information be useful?"]} +{"question_id": 32, "category": "common-sense", "turns": ["What are some subtle clues that suggest someone is pretending to understand a topic or conversation when they are actually confused or uninformed?"]} +{"question_id": 33, "category": "common-sense", "turns": ["Why might someone choose to use a paper map or ask for directions instead of relying on a GPS device or smartphone app?"]} +{"question_id": 34, "category": "common-sense", "turns": ["How can you determine if a person is genuinely interested in a conversation or simply being polite?"]} +{"question_id": 35, "category": "common-sense", "turns": ["Why might someone prefer to shop at a small, locally-owned business instead of a large chain store, even if the prices are higher?"]} +{"question_id": 36, "category": "common-sense", "turns": ["How can you assess the credibility of a source of information, such as a news article or blog post, without relying solely on the reputation of the author or publisher?"]} +{"question_id": 37, "category": "common-sense", "turns": ["Why do some people enjoy the sensation of being scared, such as by watching horror movies or going on roller coasters, while others avoid these experiences?"]} +{"question_id": 38, "category": "common-sense", "turns": ["How can observing the behavior of other people in a social situation provide clues about cultural norms and expectations?"]} +{"question_id": 39, "category": "common-sense", "turns": ["Do we have a moral obligation to explore space, or should we focus on solving Earth's problems first?"]} +{"question_id": 40, "category": "common-sense", "turns": ["In a world where automation is becoming increasingly prevalent, is it more important to prioritize job creation or technological progress?"]} +{"question_id": 41, "category": "fermi", "turns": ["How many times does the average human blink in a lifetime? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step."]} +{"question_id": 42, "category": "fermi", "turns": ["How many atoms are in a grain of salt? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step."]} +{"question_id": 43, "category": "fermi", "turns": ["How many lightning strikes occur on Earth each day? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step."]} +{"question_id": 44, "category": "fermi", "turns": ["How many balloons would it take to lift a house like in the movie \"Up\"? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step."]} +{"question_id": 45, "category": "fermi", "turns": ["How many text messages are sent globally in a minute? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step."]} +{"question_id": 46, "category": "fermi", "turns": ["How many words are spoken daily on Earth? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step."]} +{"question_id": 47, "category": "fermi", "turns": ["How many snowflakes fall during a typical winter? Try to explain your answer. 
Your explanation should take the reader through your reasoning step-by-step."]} +{"question_id": 48, "category": "fermi", "turns": ["How many pages are in all the books ever written? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step."]} +{"question_id": 49, "category": "fermi", "turns": ["How many times has the Earth orbited the Sun since the beginning of life? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step."]} +{"question_id": 50, "category": "fermi", "turns": ["How many songs have been recorded throughout history? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step."]} +{"question_id": 51, "category": "counterfactual", "turns": ["What if the Internet had been invented during the Renaissance period?"]} +{"question_id": 52, "category": "counterfactual", "turns": ["What if the Aztecs had successfully repelled the Spanish conquistadors?"]} +{"question_id": 53, "category": "counterfactual", "turns": ["What if the Black Death had not occurred in the 14th century?"]} +{"question_id": 54, "category": "counterfactual", "turns": ["What if Isaac Newton had focused on biology instead of physics?"]} +{"question_id": 55, "category": "counterfactual", "turns": ["What if the Beatles had never formed as a band?"]} +{"question_id": 56, "category": "counterfactual", "turns": ["What if Alan Turing had not cracked the Enigma code during World War II?"]} +{"question_id": 57, "category": "counterfactual", "turns": ["What if the Suez Canal had never been constructed?"]} +{"question_id": 58, "category": "counterfactual", "turns": ["What if the Maya civilization had never mysteriously collapsed?"]} +{"question_id": 59, "category": "counterfactual", "turns": ["What if Christopher Columbus had not discovered the Americas?"]} +{"question_id": 60, "category": "counterfactual", "turns": ["What if Vincent van Gogh had been a successful artist during his lifetime?"]} +{"question_id": 61, "category": "coding", "turns": ["Develop a C++ program that reads a text file line by line and counts the number of occurrences of a specific word in the file."]} +{"question_id": 62, "category": "coding", "turns": ["Implement a Python function to find the longest common subsequence of two input strings using dynamic programming."]} +{"question_id": 63, "category": "coding", "turns": ["Implement a regular expression in Python to validate an email address."]} +{"question_id": 64, "category": "coding", "turns": ["Write a program to find the nth Fibonacci number using dynamic programming."]} +{"question_id": 65, "category": "coding", "turns": ["Implement a binary search algorithm to find a specific element in a sorted array."]} +{"question_id": 66, "category": "coding", "turns": ["Implement a queue data structure using two stacks in Python."]} +{"question_id": 67, "category": "coding", "turns": ["Implement a program to find the common elements in two arrays without using any extra data structures."]} +{"question_id": 68, "category": "math", "turns": ["Given that f(x) = 5x^3 - 2x + 3, find the value of f(2)."]} +{"question_id": 69, "category": "math", "turns": ["Solve for x in the equation 3x + 10 = 5(x - 2)."]} +{"question_id": 70, "category": "math", "turns": ["If the endpoints of a line segment are (2, -2) and (10, 4), what is the length of the segment?"]} +{"question_id": 71, "category": "writing", "turns": ["Can you help me write a formal email to a potential business partner 
proposing a joint venture?"]} +{"question_id": 72, "category": "writing", "turns": ["Can you help me write a resignation letter to my current employer, while leaving on good terms and expressing gratitude for the opportunities provided?"]} +{"question_id": 73, "category": "writing", "turns": ["Use an appropriate format to structure a formal letter of recommendation for a student applying to a prestigious graduate program in computer science."]} +{"question_id": 74, "category": "writing", "turns": ["Write a compelling product launch announcement email to inform our customers of our new software solution."]} +{"question_id": 75, "category": "writing", "turns": ["Draft an apology email to a customer who experienced a delay in their order, and provide reassurance that the issue has been resolved."]} +{"question_id": 76, "category": "writing", "turns": ["Write a script for a YouTube video exploring the history and cultural significance of jazz."]} +{"question_id": 77, "category": "writing", "turns": ["Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions."]} +{"question_id": 78, "category": "writing", "turns": ["Write a captivating movie review for a recently released science fiction film, discussing its plot, characters, and special effects."]} +{"question_id": 79, "category": "writing", "turns": ["Structure a podcast script for an episode discussing the influence of streaming platforms on the music industry."]} +{"question_id": 80, "category": "writing", "turns": ["Write a symphony concert review, discussing the orchestra's performance and overall audience experience."]} diff --git a/fastchat/llm_judge/data/vicuna_bench/reference_answer/gpt-4.jsonl b/fastchat/llm_judge/data/vicuna_bench/reference_answer/gpt-4.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..f3c8368d058cd01912c41b9ca9e67ce82867bdf3 --- /dev/null +++ b/fastchat/llm_judge/data/vicuna_bench/reference_answer/gpt-4.jsonl @@ -0,0 +1,10 @@ +{"question_id": 61, "answer_id": "YdL4XwENkLCLXMbH65rjKy", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["Here's a simple C++ program that reads a text file line by line and counts the number of occurrences of a specific word in the file:\n\n```cpp\n#include \n#include \n#include \n#include \n\nint countWordOccurrences(const std::string& line, const std::string& word) {\n std::istringstream iss(line);\n std::string currentWord;\n int count = 0;\n\n while (iss >> currentWord) {\n if (currentWord == word) {\n count++;\n }\n }\n\n return count;\n}\n\nint main() {\n std::string fileName, wordToCount;\n std::cout << \"Enter the file name: \";\n std::cin >> fileName;\n std::cout << \"Enter the word to count: \";\n std::cin >> wordToCount;\n\n std::ifstream inputFile(fileName);\n if (!inputFile) {\n std::cerr << \"Error opening the file: \" << fileName << std::endl;\n return 1;\n }\n\n std::string line;\n int totalCount = 0;\n while (std::getline(inputFile, line)) {\n totalCount += countWordOccurrences(line, wordToCount);\n }\n\n inputFile.close();\n\n std::cout << \"The word \\\"\" << wordToCount << \"\\\" occurs \" << totalCount << \" times in the file.\" << std::endl;\n\n return 0;\n}\n```\n\nTo compile and run the program, save it to a file named `word_counter.cpp` and use the following commands:\n\n```sh\ng++ -o word_counter word_counter.cpp\n./word_counter\n```\n\nThe program will prompt you to enter the file name and the word you want to count. 
After that, it will display the number of occurrences of the specified word in the file."]}], "tstamp": 1686545085.5954556} +{"question_id": 62, "answer_id": "4nk9zQkeRKvfjxEEqfyXmj", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["Here's a Python function that finds the longest common subsequence of two input strings using dynamic programming:\n\n```python\ndef longest_common_subsequence(s1, s2):\n # Create a matrix to store the lengths of common subsequences\n matrix = [[0] * (len(s2) + 1) for _ in range(len(s1) + 1)]\n\n # Iterate through the strings and fill the matrix\n for i in range(len(s1) - 1, -1, -1):\n for j in range(len(s2) - 1, -1, -1):\n if s1[i] == s2[j]:\n matrix[i][j] = 1 + matrix[i + 1][j + 1]\n else:\n matrix[i][j] = max(matrix[i + 1][j], matrix[i][j + 1])\n\n # Reconstruct the longest common subsequence from the matrix\n lcs = []\n i, j = 0, 0\n while i < len(s1) and j < len(s2):\n if s1[i] == s2[j]:\n lcs.append(s1[i])\n i += 1\n j += 1\n elif matrix[i + 1][j] > matrix[i][j + 1]:\n i += 1\n else:\n j += 1\n\n return ''.join(lcs)\n\n# Example usage:\ns1 = \"ABCDGH\"\ns2 = \"AEDFHR\"\nprint(longest_common_subsequence(s1, s2)) # Output: \"ADH\"\n```\n\nThis function uses dynamic programming to build a matrix that stores the lengths of common subsequences for each pair of substrings. It then reconstructs the longest common subsequence from the matrix."]}], "tstamp": 1686544775.161658} +{"question_id": 63, "answer_id": "2MkXEfsoSuSC72BYGJuqhX", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To implement a regular expression in Python to validate an email address, you can use the `re` module. Here's a simple example:\n\n```python\nimport re\n\ndef validate_email(email):\n pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$'\n return bool(re.match(pattern, email))\n\n# Test cases\nemails = [\n \"test@example.com\",\n \"test@example.co.uk\",\n \"test+123@example.com\",\n \"test@subdomain.example.com\",\n \"test@sub_domain.example.com\",\n \"test@sub-domain.example.com\",\n \"test@example\",\n \"test@.com\",\n \"test@.example.com\",\n \"test@example..com\",\n \"test@.example..com\",\n \"test@example.com.\",\n \".test@example.com\",\n]\n\nfor email in emails:\n print(f\"{email}: {validate_email(email)}\")\n```\n\nThis regular expression checks for the following conditions:\n\n1. The email address starts with one or more alphanumeric characters, dots, underscores, percent signs, plus signs, or hyphens.\n2. Followed by the '@' symbol.\n3. Followed by one or more alphanumeric characters, dots, or hyphens.\n4. Followed by a dot.\n5. Ends with two or more alphabetic characters.\n\nPlease note that this is a simple example and may not cover all edge cases. You can modify the regular expression to better suit your specific requirements."]}], "tstamp": 1686544778.904383} +{"question_id": 64, "answer_id": "jPeywxk9g4LWyLPurjhsd2", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["Here's a Python program to find the nth Fibonacci number using dynamic programming (also known as memoization):\n\n```python\ndef fibonacci(n, memo={}):\n if n <= 0:\n return 0\n elif n == 1:\n return 1\n elif n not in memo:\n memo[n] = fibonacci(n - 1, memo) + fibonacci(n - 2, memo)\n return memo[n]\n\n# Test the function\nn = int(input(\"Enter the value of n: \"))\nresult = fibonacci(n)\nprint(f\"The {n}th Fibonacci number is: {result}\")\n```\n\nThis program defines a function `fibonacci(n, memo={})` that calculates the nth Fibonacci number using dynamic programming. 
The `memo` dictionary is used to store the results of previously calculated Fibonacci numbers, which helps to avoid redundant calculations and improve the efficiency of the program."]}], "tstamp": 1686544763.1005273} +{"question_id": 65, "answer_id": "hD8NiB3PY5aoWg95mNjnj3", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["Here's a Python implementation of the binary search algorithm to find a specific element in a sorted array:\n\n```python\ndef binary_search(arr, target):\n low, high = 0, len(arr) - 1\n\n while low <= high:\n mid = (low + high) // 2\n mid_val = arr[mid]\n\n if mid_val == target:\n return mid\n elif mid_val < target:\n low = mid + 1\n else:\n high = mid - 1\n\n return -1\n\n# Example usage:\narr = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]\ntarget = 11\n\nresult = binary_search(arr, target)\nif result != -1:\n print(f\"Element {target} found at index {result}\")\nelse:\n print(f\"Element {target} not found in the array\")\n```\n\nThis binary search function takes a sorted array `arr` and a target value `target` as input. It initializes two pointers, `low` and `high`, to the first and last indices of the array, respectively. The function then iteratively updates the pointers and calculates the middle index `mid` until the target value is found or the search space is exhausted. If the target value is found, the function returns the index of the target value; otherwise, it returns -1."]}], "tstamp": 1686544779.33291} +{"question_id": 66, "answer_id": "c4rAXsTfmioefx6icubuFB", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To implement a queue using two stacks in Python, you can use the following code:\n\n```python\nclass Queue:\n def __init__(self):\n self.stack1 = []\n self.stack2 = []\n\n def enqueue(self, item):\n self.stack1.append(item)\n\n def dequeue(self):\n if not self.stack2:\n while self.stack1:\n self.stack2.append(self.stack1.pop())\n return self.stack2.pop()\n\n def is_empty(self):\n return not self.stack1 and not self.stack2\n\n def size(self):\n return len(self.stack1) + len(self.stack2)\n\n\n# Example usage:\nqueue = Queue()\nqueue.enqueue(1)\nqueue.enqueue(2)\nqueue.enqueue(3)\n\nprint(queue.dequeue()) # Output: 1\nprint(queue.dequeue()) # Output: 2\n\nqueue.enqueue(4)\nqueue.enqueue(5)\n\nprint(queue.dequeue()) # Output: 3\nprint(queue.dequeue()) # Output: 4\nprint(queue.dequeue()) # Output: 5\n```\n\nIn this implementation, we use two stacks (`stack1` and `stack2`) to simulate the behavior of a queue. When enqueuing an item, we simply push it onto `stack1`. When dequeuing an item, we pop from `stack2`. If `stack2` is empty, we first transfer all items from `stack1` to `stack2` (reversing their order) before popping from `stack2`. This ensures that the oldest item in the queue is always at the top of `stack2`."]}], "tstamp": 1686544798.8350058} +{"question_id": 67, "answer_id": "UQ3TFrzVqXtMTWpvhqtGym", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["You can implement a program to find the common elements in two arrays without using any extra data structures by using nested loops. 
Here's an example in Python:\n\n```python\ndef find_common_elements(arr1, arr2):\n common_elements = []\n\n for i in arr1:\n for j in arr2:\n if i == j:\n common_elements.append(i)\n break\n\n return common_elements\n\n# Example usage:\narr1 = [1, 2, 3, 4, 5]\narr2 = [4, 5, 6, 7, 8]\n\ncommon_elements = find_common_elements(arr1, arr2)\nprint(\"Common elements:\", common_elements)\n```\n\nThis program defines a function `find_common_elements` that takes two arrays as input. It initializes an empty list `common_elements` to store the common elements found. Then, it iterates through each element in the first array using a loop, and for each element, it iterates through the second array using another loop. If an element from the first array is found in the second array, it is added to the `common_elements` list and the inner loop is terminated using the `break` statement.\n\nFinally, the function returns the `common_elements` list. In the example usage, the program finds the common elements between two arrays `arr1` and `arr2`, and prints the result."]}], "tstamp": 1686544805.0822344} +{"question_id": 68, "answer_id": "eBGY6U9KCAsfuedcASSK7Q", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To find the value of f(2), substitute x = 2 into the function f(x) = 5x^3 - 2x + 3.\n\nf(2) = 5(2)^3 - 2(2) + 3\nf(2) = 5(8) - 4 + 3\nf(2) = 40 - 4 + 3\nf(2) = 36 + 3\nf(2) = 39\n\nSo, the value of f(2) is 39."]}], "tstamp": 1686544893.1474469} +{"question_id": 69, "answer_id": "KsgnWzbpEvoW7RmRwaLgMT", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To solve for x in the equation 3x + 10 = 5(x - 2), follow these steps:\n\n1. Distribute the 5 on the right side of the equation:\n3x + 10 = 5x - 10\n\n2. Move the x terms to one side by subtracting 3x from both sides:\n10 = 2x - 10\n\n3. Move the constant term to the other side by adding 10 to both sides:\n20 = 2x\n\n4. Divide both sides by 2 to isolate x:\nx = 10\n\nSo, the solution is x = 10."]}], "tstamp": 1686544795.032636} +{"question_id": 70, "answer_id": "BT82nWd8ZtXTzHrvhkg6eb", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To find the length of the line segment, we can use the distance formula:\n\nDistance = \u221a((x2 - x1)^2 + (y2 - y1)^2)\n\nHere, (x1, y1) = (2, -2) and (x2, y2) = (10, 4).\n\nPlugging in the values, we get:\n\nDistance = \u221a((10 - 2)^2 + (4 - (-2))^2)\nDistance = \u221a((8)^2 + (6)^2)\nDistance = \u221a(64 + 36)\nDistance = \u221a100\nDistance = 10\n\nSo, the length of the line segment is 10 units."]}], "tstamp": 1686544809.1491597} diff --git a/fastchat/llm_judge/download_mt_bench_pregenerated.py b/fastchat/llm_judge/download_mt_bench_pregenerated.py new file mode 100644 index 0000000000000000000000000000000000000000..9546311f542af7a0824cacb0d9427f62a1a586e2 --- /dev/null +++ b/fastchat/llm_judge/download_mt_bench_pregenerated.py @@ -0,0 +1,51 @@ +""" +Download the pre-generated model answers and judgments for MT-bench. 
+""" +import os + +from fastchat.utils import run_cmd + +filenames = [ + "data/mt_bench/model_answer/alpaca-13b.jsonl", + "data/mt_bench/model_answer/baize-v2-13b.jsonl", + "data/mt_bench/model_answer/chatglm-6b.jsonl", + "data/mt_bench/model_answer/claude-instant-v1.jsonl", + "data/mt_bench/model_answer/claude-v1.jsonl", + "data/mt_bench/model_answer/dolly-v2-12b.jsonl", + "data/mt_bench/model_answer/falcon-40b-instruct.jsonl", + "data/mt_bench/model_answer/fastchat-t5-3b.jsonl", + "data/mt_bench/model_answer/gpt-3.5-turbo.jsonl", + "data/mt_bench/model_answer/gpt-4.jsonl", + "data/mt_bench/model_answer/gpt4all-13b-snoozy.jsonl", + "data/mt_bench/model_answer/guanaco-33b.jsonl", + "data/mt_bench/model_answer/guanaco-65b.jsonl", + "data/mt_bench/model_answer/h2ogpt-oasst-open-llama-13b.jsonl", + "data/mt_bench/model_answer/koala-13b.jsonl", + "data/mt_bench/model_answer/llama-13b.jsonl", + "data/mt_bench/model_answer/mpt-30b-chat.jsonl", + "data/mt_bench/model_answer/mpt-30b-instruct.jsonl", + "data/mt_bench/model_answer/mpt-7b-chat.jsonl", + "data/mt_bench/model_answer/nous-hermes-13b.jsonl", + "data/mt_bench/model_answer/oasst-sft-4-pythia-12b.jsonl", + "data/mt_bench/model_answer/oasst-sft-7-llama-30b.jsonl", + "data/mt_bench/model_answer/palm-2-chat-bison-001.jsonl", + "data/mt_bench/model_answer/rwkv-4-raven-14b.jsonl", + "data/mt_bench/model_answer/stablelm-tuned-alpha-7b.jsonl", + "data/mt_bench/model_answer/tulu-30b.jsonl", + "data/mt_bench/model_answer/vicuna-13b-v1.3.jsonl", + "data/mt_bench/model_answer/vicuna-33b-v1.3.jsonl", + "data/mt_bench/model_answer/vicuna-7b-v1.3.jsonl", + "data/mt_bench/model_answer/wizardlm-13b.jsonl", + "data/mt_bench/model_answer/wizardlm-30b.jsonl", + "data/mt_bench/model_judgment/gpt-4_single.jsonl", + "data/mt_bench/model_judgment/gpt-4_pair.jsonl", +] + + +if __name__ == "__main__": + prefix = "https://huggingface.co/spaces/lmsys/mt-bench/resolve/main/" + + for name in filenames: + os.makedirs(os.path.dirname(name), exist_ok=True) + ret = run_cmd(f"wget -q --show-progress -O {name} {prefix + name}") + assert ret == 0 diff --git a/fastchat/llm_judge/gen_api_answer.py b/fastchat/llm_judge/gen_api_answer.py new file mode 100644 index 0000000000000000000000000000000000000000..8f9c626244337e9e52f404c46a84eeb0b8806cdc --- /dev/null +++ b/fastchat/llm_judge/gen_api_answer.py @@ -0,0 +1,149 @@ +"""Generate answers with GPT-4 + +Usage: +python3 gen_api_answer.py --model gpt-3.5-turbo +""" +import argparse +import json +import os +import time +import concurrent.futures + +import openai +import shortuuid +import tqdm + +from fastchat.llm_judge.common import ( + load_questions, + temperature_config, + chat_completion_openai, + chat_completion_anthropic, + chat_completion_palm, +) +from fastchat.llm_judge.gen_model_answer import reorg_answer_file +from fastchat.model.model_adapter import get_conversation_template, ANTHROPIC_MODEL_LIST + + +def get_answer( + question: dict, model: str, num_choices: int, max_tokens: int, answer_file: str +): + assert ( + args.force_temperature is not None and "required_temperature" in question.keys() + ) == False + if args.force_temperature is not None: + temperature = args.force_temperature + elif "required_temperature" in question.keys(): + temperature = question["required_temperature"] + elif question["category"] in temperature_config: + temperature = temperature_config[question["category"]] + else: + temperature = 0.7 + + choices = [] + chat_state = None # for palm-2 model + for i in range(num_choices): + conv = 
get_conversation_template(model) + + turns = [] + for j in range(len(question["turns"])): + conv.append_message(conv.roles[0], question["turns"][j]) + conv.append_message(conv.roles[1], None) + + if model in ANTHROPIC_MODEL_LIST: + output = chat_completion_anthropic(model, conv, temperature, max_tokens) + elif model == "palm-2-chat-bison-001": + chat_state, output = chat_completion_palm( + chat_state, model, conv, temperature, max_tokens + ) + else: + output = chat_completion_openai(model, conv, temperature, max_tokens) + + conv.update_last_message(output) + turns.append(output) + + choices.append({"index": i, "turns": turns}) + + # Dump answers + ans = { + "question_id": question["question_id"], + "answer_id": shortuuid.uuid(), + "model_id": model, + "choices": choices, + "tstamp": time.time(), + } + + os.makedirs(os.path.dirname(answer_file), exist_ok=True) + with open(answer_file, "a") as fout: + fout.write(json.dumps(ans) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--bench-name", + type=str, + default="mt_bench", + help="The name of the benchmark question set.", + ) + parser.add_argument("--answer-file", type=str, help="The output answer file.") + parser.add_argument("--model", type=str, default="gpt-3.5-turbo") + parser.add_argument( + "--num-choices", + type=int, + default=1, + help="How many completion choices to generate.", + ) + parser.add_argument( + "--force-temperature", type=float, help="Forcibly set a sampling temperature." + ) + parser.add_argument( + "--max-tokens", + type=int, + default=1024, + help="The maximum number of new generated tokens.", + ) + parser.add_argument( + "--question-begin", + type=int, + help="A debug option. The begin index of questions.", + ) + parser.add_argument( + "--question-end", type=int, help="A debug option. The end index of questions." + ) + parser.add_argument( + "--parallel", type=int, default=1, help="The number of concurrent API calls." 
+ ) + parser.add_argument("--openai-api-base", type=str, default=None) + args = parser.parse_args() + + if args.openai_api_base is not None: + openai.api_base = args.openai_api_base + + question_file = f"data/{args.bench_name}/question.jsonl" + questions = load_questions(question_file, args.question_begin, args.question_end) + + if args.answer_file: + answer_file = args.answer_file + else: + answer_file = f"data/{args.bench_name}/model_answer/{args.model}.jsonl" + print(f"Output to {answer_file}") + + with concurrent.futures.ThreadPoolExecutor(max_workers=args.parallel) as executor: + futures = [] + for question in questions: + future = executor.submit( + get_answer, + question, + args.model, + args.num_choices, + args.max_tokens, + answer_file, + ) + futures.append(future) + + for future in tqdm.tqdm( + concurrent.futures.as_completed(futures), total=len(futures) + ): + future.result() + + reorg_answer_file(answer_file) diff --git a/fastchat/llm_judge/gen_judgment.py b/fastchat/llm_judge/gen_judgment.py new file mode 100644 index 0000000000000000000000000000000000000000..a1c70b29545782f778ef349bf567a2e28b938428 --- /dev/null +++ b/fastchat/llm_judge/gen_judgment.py @@ -0,0 +1,322 @@ +""" +Usage: +python gen_judgment.py --model-list [LIST-OF-MODEL-ID] --parallel [num-concurrent-api-call] --mode [single|pairwise-baseline|pairwise-all] +""" +import argparse +from concurrent.futures import ThreadPoolExecutor +import json + +import numpy as np +from tqdm import tqdm + +from fastchat.llm_judge.common import ( + load_questions, + load_model_answers, + load_judge_prompts, + check_data, + play_a_match_pair, + play_a_match_single, + get_model_list, + Judge, + MatchPair, + MatchSingle, + NEED_REF_CATS, +) + + +def make_match( + questions, + models, + model_answers, + judge, + baseline_model, + ref_answers=None, + multi_turn=False, +): + matches = [] + for q in questions: + if multi_turn and len(q["turns"]) != 2: + continue + for i in range(len(models)): + q_id = q["question_id"] + m_1 = models[i] + m_2 = baseline_model + if m_1 == m_2: + continue + a_1 = model_answers[m_1][q_id] + a_2 = model_answers[baseline_model][q_id] + if ref_answers is not None: + ref = ref_answers[judge.model_name][q_id] + match = MatchPair( + dict(q), + m_1, + m_2, + a_1, + a_2, + judge, + ref_answer=ref, + multi_turn=multi_turn, + ) + else: + match = MatchPair( + dict(q), m_1, m_2, a_1, a_2, judge, multi_turn=multi_turn + ) + matches.append(match) + return matches + + +def make_match_all_pairs( + questions, + models, + model_answers, + judge, + baseline_model=None, + ref_answers=None, + multi_turn=False, +): + matches = [] + for q in questions: + if multi_turn and len(q["turns"]) != 2: + continue + for i in range(len(models)): + for j in range(i + 1, len(models)): + q_id = q["question_id"] + m_1 = models[i] + m_2 = models[j] + a_1 = model_answers[m_1][q_id] + a_2 = model_answers[m_2][q_id] + if ref_answers is not None: + ref = ref_answers[judge.model_name][q_id] + match = MatchPair( + dict(q), + m_1, + m_2, + a_1, + a_2, + judge, + ref_answer=ref, + multi_turn=multi_turn, + ) + else: + match = MatchPair( + dict(q), m_1, m_2, a_1, a_2, judge, multi_turn=multi_turn + ) + matches.append(match) + return matches + + +def make_match_single( + questions, + models, + model_answers, + judge, + baseline_model=None, + ref_answers=None, + multi_turn=False, +): + matches = [] + for q in questions: + if multi_turn and len(q["turns"]) != 2: + continue + for i in range(len(models)): + q_id = q["question_id"] + m = models[i] + a = 
model_answers[m][q_id] + if ref_answers is not None: + ref = ref_answers[judge.model_name][q_id] + matches.append( + MatchSingle( + dict(q), m, a, judge, ref_answer=ref, multi_turn=multi_turn + ) + ) + else: + matches.append(MatchSingle(dict(q), m, a, judge, multi_turn=multi_turn)) + return matches + + +def make_judge_pairwise(judge_model, judge_prompts): + judges = {} + judges["default"] = Judge(judge_model, judge_prompts["pair-v2"]) + judges["math"] = Judge(judge_model, judge_prompts["pair-math-v1"], ref_based=True) + judges["default-mt"] = Judge( + judge_model, judge_prompts["pair-v2-multi-turn"], multi_turn=True + ) + judges["math-mt"] = Judge( + judge_model, + judge_prompts["pair-math-v1-multi-turn"], + ref_based=True, + multi_turn=True, + ) + return judges + + +def make_judge_single(judge_model, judge_prompts): + judges = {} + judges["default"] = Judge(judge_model, judge_prompts["single-v1"]) + judges["math"] = Judge(judge_model, judge_prompts["single-math-v1"], ref_based=True) + judges["default-mt"] = Judge( + judge_model, judge_prompts["single-v1-multi-turn"], multi_turn=True + ) + judges["math-mt"] = Judge( + judge_model, + judge_prompts["single-math-v1-multi-turn"], + ref_based=True, + multi_turn=True, + ) + return judges + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--bench-name", + type=str, + default="mt_bench", + help="The name of the benchmark question set.", + ) + parser.add_argument( + "--judge-file", + type=str, + default="data/judge_prompts.jsonl", + help="The file of judge prompts.", + ) + parser.add_argument("--judge-model", type=str, default="gpt-4") + parser.add_argument("--baseline-model", type=str, default="gpt-3.5-turbo") + parser.add_argument( + "--mode", + type=str, + default="single", + choices=["pairwise-baseline", "pairwise-all", "single"], + help=( + "Evaluation mode. " + "`pairwise-baseline` runs pairwise comparision against a baseline. " + "`pairwise-all` runs pairwise comparision between all pairs. " + "`single` runs single answer grading." + ), + ) + parser.add_argument( + "--model-list", + type=str, + nargs="+", + default=None, + help="A list of models to be evaluated", + ) + parser.add_argument( + "--parallel", type=int, default=1, help="The number of concurrent API calls." + ) + parser.add_argument( + "--first-n", type=int, help="A debug option. Only run the first `n` judgments." 
+ ) + args = parser.parse_args() + + question_file = f"data/{args.bench_name}/question.jsonl" + answer_dir = f"data/{args.bench_name}/model_answer" + ref_answer_dir = f"data/{args.bench_name}/reference_answer" + + # Load questions + questions = load_questions(question_file, None, None) + + # Load answers + model_answers = load_model_answers(answer_dir) + ref_answers = load_model_answers(ref_answer_dir) + + # Load judge + judge_prompts = load_judge_prompts(args.judge_file) + + if args.first_n: + questions = questions[: args.first_n] + + if args.model_list is None: + models = get_model_list(answer_dir) + else: + models = args.model_list + + if args.mode == "single": + judges = make_judge_single(args.judge_model, judge_prompts) + play_a_match_func = play_a_match_single + output_file = ( + f"data/{args.bench_name}/model_judgment/{args.judge_model}_single.jsonl" + ) + make_match_func = make_match_single + baseline_model = None + else: + judges = make_judge_pairwise(args.judge_model, judge_prompts) + play_a_match_func = play_a_match_pair + output_file = ( + f"data/{args.bench_name}/model_judgment/{args.judge_model}_pair.jsonl" + ) + if args.mode == "pairwise-all": + make_match_func = make_match_all_pairs + baseline_model = None + else: + make_match_func = make_match + baseline_model = args.baseline_model + + check_data(questions, model_answers, ref_answers, models, judges) + + question_math = [q for q in questions if q["category"] in NEED_REF_CATS] + question_default = [q for q in questions if q["category"] not in NEED_REF_CATS] + + # Make matches + matches = [] + matches += make_match_func( + question_default, models, model_answers, judges["default"], baseline_model + ) + matches += make_match_func( + question_math, + models, + model_answers, + judges["math"], + baseline_model, + ref_answers, + ) + matches += make_match_func( + question_default, + models, + model_answers, + judges["default-mt"], + baseline_model, + multi_turn=True, + ) + matches += make_match_func( + question_math, + models, + model_answers, + judges["math-mt"], + baseline_model, + ref_answers, + multi_turn=True, + ) + + match_stat = {} + match_stat["bench_name"] = args.bench_name + match_stat["mode"] = args.mode + match_stat["judge"] = args.judge_model + match_stat["baseline"] = baseline_model + match_stat["model_list"] = models + match_stat["total_num_questions"] = len(questions) + match_stat["total_num_matches"] = len(matches) + match_stat["output_path"] = output_file + + # Show match stats and prompt enter to continue + print("Stats:") + print(json.dumps(match_stat, indent=4)) + input("Press Enter to confirm...") + + # Play matches + if args.parallel == 1: + for match in tqdm(matches): + play_a_match_func(match, output_file=output_file) + else: + + def play_a_match_wrapper(match): + play_a_match_func(match, output_file=output_file) + + np.random.seed(0) + np.random.shuffle(matches) + + with ThreadPoolExecutor(args.parallel) as executor: + for match in tqdm( + executor.map(play_a_match_wrapper, matches), total=len(matches) + ): + pass diff --git a/fastchat/llm_judge/gen_model_answer.py b/fastchat/llm_judge/gen_model_answer.py new file mode 100644 index 0000000000000000000000000000000000000000..44166fb08852364439db93b2fdcbba03d7331c2a --- /dev/null +++ b/fastchat/llm_judge/gen_model_answer.py @@ -0,0 +1,304 @@ +"""Generate answers with local models. 
+ +Usage: +python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0 +""" +import argparse +import json +import os +import random +import time + +import shortuuid +import torch +from tqdm import tqdm + +from fastchat.llm_judge.common import load_questions, temperature_config +from fastchat.model import load_model, get_conversation_template +from fastchat.utils import str_to_torch_dtype + + +def run_eval( + model_path, + model_id, + question_file, + question_begin, + question_end, + answer_file, + max_new_token, + num_choices, + num_gpus_per_model, + num_gpus_total, + max_gpu_memory, + dtype, + revision, +): + questions = load_questions(question_file, question_begin, question_end) + # random shuffle the questions to balance the loading + random.shuffle(questions) + + # Split the question file into `num_gpus` files + assert num_gpus_total % num_gpus_per_model == 0 + use_ray = num_gpus_total // num_gpus_per_model > 1 + + if use_ray: + get_answers_func = ray.remote(num_gpus=num_gpus_per_model)( + get_model_answers + ).remote + else: + get_answers_func = get_model_answers + + chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model) + ans_handles = [] + for i in range(0, len(questions), chunk_size): + ans_handles.append( + get_answers_func( + model_path, + model_id, + questions[i : i + chunk_size], + answer_file, + max_new_token, + num_choices, + num_gpus_per_model, + max_gpu_memory, + dtype=dtype, + revision=revision, + ) + ) + + if use_ray: + ray.get(ans_handles) + + +@torch.inference_mode() +def get_model_answers( + model_path, + model_id, + questions, + answer_file, + max_new_token, + num_choices, + num_gpus_per_model, + max_gpu_memory, + dtype, + revision, +): + model, tokenizer = load_model( + model_path, + revision=revision, + device="cuda", + num_gpus=num_gpus_per_model, + max_gpu_memory=max_gpu_memory, + dtype=dtype, + load_8bit=False, + cpu_offloading=False, + debug=False, + ) + + for question in tqdm(questions): + if question["category"] in temperature_config: + temperature = temperature_config[question["category"]] + else: + temperature = 0.7 + + choices = [] + for i in range(num_choices): + torch.manual_seed(i) + conv = get_conversation_template(model_id) + turns = [] + for j in range(len(question["turns"])): + qs = question["turns"][j] + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + input_ids = tokenizer([prompt]).input_ids + + if temperature < 1e-4: + do_sample = False + else: + do_sample = True + + # some models may error out when generating long outputs + try: + output_ids = model.generate( + torch.as_tensor(input_ids).cuda(), + do_sample=do_sample, + temperature=temperature, + max_new_tokens=max_new_token, + ) + if model.config.is_encoder_decoder: + output_ids = output_ids[0] + else: + output_ids = output_ids[0][len(input_ids[0]) :] + + # be consistent with the template's stop_token_ids + if conv.stop_token_ids: + stop_token_ids_index = [ + i + for i, id in enumerate(output_ids) + if id in conv.stop_token_ids + ] + if len(stop_token_ids_index) > 0: + output_ids = output_ids[: stop_token_ids_index[0]] + + output = tokenizer.decode( + output_ids, + spaces_between_special_tokens=False, + ) + if conv.stop_str and isinstance(conv.stop_str, list): + stop_str_indices = sorted( + [ + output.find(stop_str) + for stop_str in conv.stop_str + if output.find(stop_str) > 0 + ] + ) + if len(stop_str_indices) > 0: + output = output[: stop_str_indices[0]] + elif conv.stop_str 
and output.find(conv.stop_str) > 0: + output = output[: output.find(conv.stop_str)] + + for special_token in tokenizer.special_tokens_map.values(): + if isinstance(special_token, list): + for special_tok in special_token: + output = output.replace(special_tok, "") + else: + output = output.replace(special_token, "") + output = output.strip() + + if conv.name == "xgen" and output.startswith("Assistant:"): + output = output.replace("Assistant:", "", 1).strip() + except RuntimeError as e: + print("ERROR question ID: ", question["question_id"]) + output = "ERROR" + + conv.update_last_message(output) + turns.append(output) + + choices.append({"index": i, "turns": turns}) + + # Dump answers + os.makedirs(os.path.dirname(answer_file), exist_ok=True) + with open(os.path.expanduser(answer_file), "a") as fout: + ans_json = { + "question_id": question["question_id"], + "answer_id": shortuuid.uuid(), + "model_id": model_id, + "choices": choices, + "tstamp": time.time(), + } + fout.write(json.dumps(ans_json) + "\n") + + +def reorg_answer_file(answer_file): + """Sort by question id and de-duplication""" + answers = {} + with open(answer_file, "r") as fin: + for l in fin: + qid = json.loads(l)["question_id"] + answers[qid] = l + + qids = sorted(list(answers.keys())) + with open(answer_file, "w") as fout: + for qid in qids: + fout.write(answers[qid]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-path", + type=str, + required=True, + help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", + ) + parser.add_argument( + "--model-id", type=str, required=True, help="A custom name for the model." + ) + parser.add_argument( + "--bench-name", + type=str, + default="mt_bench", + help="The name of the benchmark question set.", + ) + parser.add_argument( + "--question-begin", + type=int, + help="A debug option. The begin index of questions.", + ) + parser.add_argument( + "--question-end", type=int, help="A debug option. The end index of questions." + ) + parser.add_argument("--answer-file", type=str, help="The output answer file.") + parser.add_argument( + "--max-new-token", + type=int, + default=1024, + help="The maximum number of new generated tokens.", + ) + parser.add_argument( + "--num-choices", + type=int, + default=1, + help="How many completion choices to generate.", + ) + parser.add_argument( + "--num-gpus-per-model", + type=int, + default=1, + help="The number of GPUs per model.", + ) + parser.add_argument( + "--num-gpus-total", type=int, default=1, help="The total number of GPUs." + ) + parser.add_argument( + "--max-gpu-memory", + type=str, + help="Maxmum GPU memory used for model weights per GPU.", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["float32", "float16", "bfloat16"], + help="Override the default dtype. 
If not set, it will use float16 on GPU and float32 on CPU.", + default=None, + ) + parser.add_argument( + "--revision", + type=str, + default="main", + help="The model revision to load.", + ) + + args = parser.parse_args() + + if args.num_gpus_total // args.num_gpus_per_model > 1: + import ray + + ray.init() + + question_file = f"data/{args.bench_name}/question.jsonl" + if args.answer_file: + answer_file = args.answer_file + else: + answer_file = f"data/{args.bench_name}/model_answer/{args.model_id}.jsonl" + + print(f"Output to {answer_file}") + + run_eval( + model_path=args.model_path, + model_id=args.model_id, + question_file=question_file, + question_begin=args.question_begin, + question_end=args.question_end, + answer_file=answer_file, + max_new_token=args.max_new_token, + num_choices=args.num_choices, + num_gpus_per_model=args.num_gpus_per_model, + num_gpus_total=args.num_gpus_total, + max_gpu_memory=args.max_gpu_memory, + dtype=str_to_torch_dtype(args.dtype), + revision=args.revision, + ) + + reorg_answer_file(answer_file) diff --git a/fastchat/llm_judge/qa_browser.py b/fastchat/llm_judge/qa_browser.py new file mode 100644 index 0000000000000000000000000000000000000000..1107756dbf9baccfe1fc5b0bdbdc9796f6f61807 --- /dev/null +++ b/fastchat/llm_judge/qa_browser.py @@ -0,0 +1,420 @@ +""" +Usage: +python3 qa_browser.py --share +""" + +import argparse +from collections import defaultdict +import re + +import gradio as gr + +from fastchat.llm_judge.common import ( + load_questions, + load_model_answers, + load_single_model_judgments, + load_pairwise_model_judgments, + resolve_single_judgment_dict, + resolve_pairwise_judgment_dict, + get_single_judge_explanation, + get_pairwise_judge_explanation, +) + + +questions = [] +model_answers = {} + +model_judgments_normal_single = {} +model_judgments_math_single = {} + +model_judgments_normal_pairwise = {} +model_judgments_math_pairwise = {} + +question_selector_map = {} +category_selector_map = defaultdict(list) + + +def display_question(category_selector, request: gr.Request): + choices = category_selector_map[category_selector] + return gr.Dropdown( + value=choices[0], + choices=choices, + ) + + +def display_pairwise_answer( + question_selector, model_selector1, model_selector2, request: gr.Request +): + q = question_selector_map[question_selector] + qid = q["question_id"] + + ans1 = model_answers[model_selector1][qid] + ans2 = model_answers[model_selector2][qid] + + chat_mds = pairwise_to_gradio_chat_mds(q, ans1, ans2) + gamekey = (qid, model_selector1, model_selector2) + + judgment_dict = resolve_pairwise_judgment_dict( + q, + model_judgments_normal_pairwise, + model_judgments_math_pairwise, + multi_turn=False, + ) + + explanation = ( + "##### Model Judgment (first turn)\n" + + get_pairwise_judge_explanation(gamekey, judgment_dict) + ) + + judgment_dict_turn2 = resolve_pairwise_judgment_dict( + q, + model_judgments_normal_pairwise, + model_judgments_math_pairwise, + multi_turn=True, + ) + + explanation_turn2 = ( + "##### Model Judgment (second turn)\n" + + get_pairwise_judge_explanation(gamekey, judgment_dict_turn2) + ) + + return chat_mds + [explanation] + [explanation_turn2] + + +def display_single_answer(question_selector, model_selector1, request: gr.Request): + q = question_selector_map[question_selector] + qid = q["question_id"] + + ans1 = model_answers[model_selector1][qid] + + chat_mds = single_to_gradio_chat_mds(q, ans1) + gamekey = (qid, model_selector1) + + judgment_dict = resolve_single_judgment_dict( + q, 
model_judgments_normal_single, model_judgments_math_single, multi_turn=False + ) + + explanation = "##### Model Judgment (first turn)\n" + get_single_judge_explanation( + gamekey, judgment_dict + ) + + judgment_dict_turn2 = resolve_single_judgment_dict( + q, model_judgments_normal_single, model_judgments_math_single, multi_turn=True + ) + + explanation_turn2 = ( + "##### Model Judgment (second turn)\n" + + get_single_judge_explanation(gamekey, judgment_dict_turn2) + ) + + return chat_mds + [explanation] + [explanation_turn2] + + +newline_pattern1 = re.compile("\n\n(\d+\. )") +newline_pattern2 = re.compile("\n\n(- )") + + +def post_process_answer(x): + """Fix Markdown rendering problems.""" + x = x.replace("\u2022", "- ") + x = re.sub(newline_pattern1, "\n\g<1>", x) + x = re.sub(newline_pattern2, "\n\g<1>", x) + return x + + +def pairwise_to_gradio_chat_mds(question, ans_a, ans_b, turn=None): + end = len(question["turns"]) if turn is None else turn + 1 + + mds = ["", "", "", "", "", "", ""] + for i in range(end): + base = i * 3 + if i == 0: + mds[base + 0] = "##### User\n" + question["turns"][i] + else: + mds[base + 0] = "##### User's follow-up question \n" + question["turns"][i] + mds[base + 1] = "##### Assistant A\n" + post_process_answer( + ans_a["choices"][0]["turns"][i].strip() + ) + mds[base + 2] = "##### Assistant B\n" + post_process_answer( + ans_b["choices"][0]["turns"][i].strip() + ) + + ref = question.get("reference", ["", ""]) + + ref_md = "" + if turn is None: + if ref[0] != "" or ref[1] != "": + mds[6] = f"##### Reference Solution\nQ1. {ref[0]}\nQ2. {ref[1]}" + else: + x = ref[turn] if turn < len(ref) else "" + if x: + mds[6] = f"##### Reference Solution\n{ref[turn]}" + else: + mds[6] = "" + return mds + + +def single_to_gradio_chat_mds(question, ans, turn=None): + end = len(question["turns"]) if turn is None else turn + 1 + + mds = ["", "", "", "", ""] + for i in range(end): + base = i * 2 + if i == 0: + mds[base + 0] = "##### User\n" + question["turns"][i] + else: + mds[base + 0] = "##### User's follow-up question \n" + question["turns"][i] + mds[base + 1] = "##### Assistant A\n" + post_process_answer( + ans["choices"][0]["turns"][i].strip() + ) + + ref = question.get("reference", ["", ""]) + + ref_md = "" + if turn is None: + if ref[0] != "" or ref[1] != "": + mds[4] = f"##### Reference Solution\nQ1. {ref[0]}\nQ2. {ref[1]}" + else: + x = ref[turn] if turn < len(ref) else "" + if x: + mds[4] = f"##### Reference Solution\n{ref[turn]}" + else: + mds[4] = "" + return mds + + +def build_question_selector_map(): + global question_selector_map, category_selector_map + + # Build question selector map + for q in questions: + preview = f"{q['question_id']}: " + q["turns"][0][:128] + "..." 
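# Hedged illustration (not part of the patch): post_process_answer() defined above only
# collapses the blank line before numbered/bulleted list items and rewrites the unicode
# bullet, so Gradio renders model answers as tight Markdown lists. A self-contained
# sketch of the same transformation:
import re

def _demo_post_process(x: str) -> str:
    x = x.replace("\u2022", "- ")
    x = re.sub(r"\n\n(\d+\. )", r"\n\g<1>", x)
    x = re.sub(r"\n\n(- )", r"\n\g<1>", x)
    return x

assert _demo_post_process("Steps:\n\n1. Install\n\n2. Run") == "Steps:\n1. Install\n2. Run"
assert _demo_post_process("\u2022first point") == "- first point"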
+ question_selector_map[preview] = q + category_selector_map[q["category"]].append(preview) + + +def build_pairwise_browser_tab(): + global question_selector_map, category_selector_map + + models = list(model_answers.keys()) + num_sides = 2 + num_turns = 2 + side_names = ["A", "B"] + + question_selector_choices = list(question_selector_map.keys()) + category_selector_choices = list(category_selector_map.keys()) + + # Selectors + with gr.Row(): + with gr.Column(scale=1, min_width=200): + category_selector = gr.Dropdown( + choices=category_selector_choices, label="Category", container=False + ) + with gr.Column(scale=100): + question_selector = gr.Dropdown( + choices=question_selector_choices, label="Question", container=False + ) + + model_selectors = [None] * num_sides + with gr.Row(): + for i in range(num_sides): + with gr.Column(): + if i == 0: + value = models[0] + else: + value = "gpt-3.5-turbo" + model_selectors[i] = gr.Dropdown( + choices=models, + value=value, + label=f"Model {side_names[i]}", + container=False, + ) + + # Conversation + chat_mds = [] + for i in range(num_turns): + chat_mds.append(gr.Markdown(elem_id=f"user_question_{i+1}")) + with gr.Row(): + for j in range(num_sides): + with gr.Column(scale=100): + chat_mds.append(gr.Markdown()) + + if j == 0: + with gr.Column(scale=1, min_width=8): + gr.Markdown() + reference = gr.Markdown(elem_id=f"reference") + chat_mds.append(reference) + + model_explanation = gr.Markdown(elem_id="model_explanation") + model_explanation2 = gr.Markdown(elem_id="model_explanation") + + # Callbacks + category_selector.change(display_question, [category_selector], [question_selector]) + question_selector.change( + display_pairwise_answer, + [question_selector] + model_selectors, + chat_mds + [model_explanation] + [model_explanation2], + ) + + for i in range(num_sides): + model_selectors[i].change( + display_pairwise_answer, + [question_selector] + model_selectors, + chat_mds + [model_explanation] + [model_explanation2], + ) + + return (category_selector,) + + +def build_single_answer_browser_tab(): + global question_selector_map, category_selector_map + + models = list(model_answers.keys()) + num_sides = 1 + num_turns = 2 + side_names = ["A"] + + question_selector_choices = list(question_selector_map.keys()) + category_selector_choices = list(category_selector_map.keys()) + + # Selectors + with gr.Row(): + with gr.Column(scale=1, min_width=200): + category_selector = gr.Dropdown( + choices=category_selector_choices, label="Category", container=False + ) + with gr.Column(scale=100): + question_selector = gr.Dropdown( + choices=question_selector_choices, label="Question", container=False + ) + + model_selectors = [None] * num_sides + with gr.Row(): + for i in range(num_sides): + with gr.Column(): + model_selectors[i] = gr.Dropdown( + choices=models, + value=models[i] if len(models) > i else "", + label=f"Model {side_names[i]}", + container=False, + ) + + # Conversation + chat_mds = [] + for i in range(num_turns): + chat_mds.append(gr.Markdown(elem_id=f"user_question_{i+1}")) + with gr.Row(): + for j in range(num_sides): + with gr.Column(scale=100): + chat_mds.append(gr.Markdown()) + + if j == 0: + with gr.Column(scale=1, min_width=8): + gr.Markdown() + + reference = gr.Markdown(elem_id=f"reference") + chat_mds.append(reference) + + model_explanation = gr.Markdown(elem_id="model_explanation") + model_explanation2 = gr.Markdown(elem_id="model_explanation") + + # Callbacks + category_selector.change(display_question, [category_selector], 
[question_selector]) + question_selector.change( + display_single_answer, + [question_selector] + model_selectors, + chat_mds + [model_explanation] + [model_explanation2], + ) + + for i in range(num_sides): + model_selectors[i].change( + display_single_answer, + [question_selector] + model_selectors, + chat_mds + [model_explanation] + [model_explanation2], + ) + + return (category_selector,) + + +block_css = """ +#user_question_1 { + background-color: #DEEBF7; +} +#user_question_2 { + background-color: #E2F0D9; +} +#reference { + background-color: #FFF2CC; +} +#model_explanation { + background-color: #FBE5D6; +} +""" + + +def load_demo(): + dropdown_update = gr.Dropdown.update(value=list(category_selector_map.keys())[0]) + return dropdown_update, dropdown_update + + +def build_demo(): + build_question_selector_map() + + with gr.Blocks( + title="MT-Bench Browser", + theme=gr.themes.Base(text_size=gr.themes.sizes.text_lg), + css=block_css, + ) as demo: + gr.Markdown( + """ +# MT-Bench Browser +The code to generate answers and judgments is at [fastchat.llm_judge](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge). +""" + ) + with gr.Tab("Single Answer Grading"): + (category_selector,) = build_single_answer_browser_tab() + with gr.Tab("Pairwise Comparison"): + (category_selector2,) = build_pairwise_browser_tab() + demo.load(load_demo, [], [category_selector, category_selector2]) + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument("--share", action="store_true") + parser.add_argument("--bench-name", type=str, default="mt_bench") + args = parser.parse_args() + print(args) + + question_file = f"data/{args.bench_name}/question.jsonl" + answer_dir = f"data/{args.bench_name}/model_answer" + pairwise_model_judgment_file = ( + f"data/{args.bench_name}/model_judgment/gpt-4_pair.jsonl" + ) + single_model_judgment_file = ( + f"data/{args.bench_name}/model_judgment/gpt-4_single.jsonl" + ) + + # Load questions + questions = load_questions(question_file, None, None) + + # Load answers + model_answers = load_model_answers(answer_dir) + + # Load model judgments + model_judgments_normal_single = ( + model_judgments_math_single + ) = load_single_model_judgments(single_model_judgment_file) + model_judgments_normal_pairwise = ( + model_judgments_math_pairwise + ) = load_pairwise_model_judgments(pairwise_model_judgment_file) + + demo = build_demo() + demo.queue( + default_concurrency_limit=10, status_update_rate=10, api_open=False + ).launch( + server_name=args.host, server_port=args.port, share=args.share, max_threads=200 + ) diff --git a/fastchat/llm_judge/show_result.py b/fastchat/llm_judge/show_result.py new file mode 100644 index 0000000000000000000000000000000000000000..f20801b11a2e71c196d0c574c3c10aa28968c1fd --- /dev/null +++ b/fastchat/llm_judge/show_result.py @@ -0,0 +1,130 @@ +""" +Usage: +python3 show_result.py --mode [single|pairwise-baseline|pairwise-all] +""" +import argparse +import pandas as pd + + +def display_result_single(args): + if args.input_file is None: + input_file = ( + f"data/{args.bench_name}/model_judgment/{args.judge_model}_single.jsonl" + ) + else: + input_file = args.input_file + + print(f"Input file: {input_file}") + df_all = pd.read_json(input_file, lines=True) + df = df_all[["model", "score", "turn"]] + df = df[df["score"] != -1] + + if args.model_list is not None: + df = df[df["model"].isin(args.model_list)] 
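# Hedged illustration (not part of the patch): display_result_single() only relies on the
# "model", "score" and "turn" columns of the judgment file. The snippet below mirrors the
# groupby/mean aggregation that follows, on made-up scores.
import pandas as pd

_demo = pd.DataFrame(
    [
        {"model": "vicuna-7b-v1.5", "score": 7.0, "turn": 1},
        {"model": "vicuna-7b-v1.5", "score": 6.0, "turn": 2},
        {"model": "gpt-3.5-turbo", "score": 8.5, "turn": 1},
        {"model": "gpt-3.5-turbo", "score": 8.0, "turn": 2},
    ]
)
print(_demo[_demo["turn"] == 1].groupby(["model", "turn"]).mean())   # first-turn table
print(_demo[["model", "score"]].groupby(["model"]).mean())           # overall average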
+ + print("\n########## First turn ##########") + df_1 = df[df["turn"] == 1].groupby(["model", "turn"]).mean() + print(df_1.sort_values(by="score", ascending=False)) + + if args.bench_name == "mt_bench": + print("\n########## Second turn ##########") + df_2 = df[df["turn"] == 2].groupby(["model", "turn"]).mean() + print(df_2.sort_values(by="score", ascending=False)) + + print("\n########## Average ##########") + df_3 = df[["model", "score"]].groupby(["model"]).mean() + print(df_3.sort_values(by="score", ascending=False)) + + +def display_result_pairwise(args): + if args.input_file is None: + input_file = ( + f"data/{args.bench_name}/model_judgment/{args.judge_model}_pair.jsonl" + ) + else: + input_file = args.input_file + + print(f"Input file: {input_file}") + df_all = pd.read_json(input_file, lines=True) + df_all = df_all[(df_all["g1_winner"] != "error") & (df_all["g2_winner"] != "error")] + + model_list = ( + df_all["model_1"].unique().tolist() + df_all["model_2"].unique().tolist() + ) + model_list = list(set(model_list)) + + list_res = [] + # traverse df row by row + for index, row in df_all.iterrows(): + if args.model_list is not None and row["model_1"] not in args.model_list: + continue + if args.baseline_model is not None: + if args.baseline_model not in [row["model_1"], row["model_2"]]: + continue + if row["g1_winner"] == "tie" or row["g1_winner"] != row["g2_winner"]: + list_res.append({"model": row["model_1"], "win": 0, "loss": 0, "tie": 1}) + list_res.append({"model": row["model_2"], "win": 0, "loss": 0, "tie": 1}) + else: + if row["g1_winner"] == "model_1": + winner = row["model_1"] + loser = row["model_2"] + else: + winner = row["model_2"] + loser = row["model_1"] + list_res.append({"model": winner, "win": 1, "loss": 0, "tie": 0}) + list_res.append({"model": loser, "win": 0, "loss": 1, "tie": 0}) + + df = pd.DataFrame(list_res) + df = df.groupby(["model"]).sum() + + # remove baseline model + if args.baseline_model is not None: + df = df[df.index != args.baseline_model] + # add win rate + df["win_rate"] = df["win"] / (df["win"] + df["loss"] + df["tie"]) + df["loss_rate"] = df["loss"] / (df["win"] + df["loss"] + df["tie"]) + # each tie counts as 0.5 win + 0.5 loss + df["win_rate_adjusted"] = (df["win"] + 0.5 * df["tie"]) / ( + df["win"] + df["loss"] + df["tie"] + ) + # print(df.sort_values(by="win_rate", ascending=False)) + # print(df.sort_values(by="loss_rate", ascending=True)) + print(df.sort_values(by="win_rate_adjusted", ascending=False)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--bench-name", type=str, default="mt_bench") + parser.add_argument("--input-file", type=str) + parser.add_argument("--judge-model", type=str, default="gpt-4") + parser.add_argument("--baseline-model", type=str, default="gpt-3.5-turbo") + parser.add_argument( + "--model-list", + type=str, + nargs="+", + default=None, + help="A list of models to be evaluated", + ) + parser.add_argument( + "--mode", + type=str, + default="single", + choices=["pairwise-baseline", "pairwise-all", "single"], + help=( + "Evaluation mode. " + "`pairwise-baseline` runs pairwise comparision against a baseline. " + "`pairwise-all` runs pairwise comparision between all pairs. " + "`single` runs single answer grading." 
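# Hedged worked example (not part of the patch): with the win/loss/tie counts produced
# above, a model that has 7 wins, 2 losses and 1 tie gets
#     win_rate          = 7 / (7 + 2 + 1)     = 0.70
#     win_rate_adjusted = (7 + 0.5 * 1) / 10  = 0.75
# i.e. each tie is split as half a win and half a loss before ranking.
assert (7 + 0.5 * 1) / (7 + 2 + 1) == 0.75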
+ ), + ) + args = parser.parse_args() + + if args.mode == "single": + display_result_func = display_result_single + else: + if args.mode == "pairwise-all": + args.baseline_model = None + display_result_func = display_result_pairwise + + print(f"Mode: {args.mode}") + display_result_func(args) diff --git a/fastchat/model/__init__.py b/fastchat/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..29767dce6ae41b72ecabfed477531684a4241d55 --- /dev/null +++ b/fastchat/model/__init__.py @@ -0,0 +1,5 @@ +from fastchat.model.model_adapter import ( + load_model, + get_conversation_template, + add_model_args, +) diff --git a/fastchat/model/apply_delta.py b/fastchat/model/apply_delta.py new file mode 100644 index 0000000000000000000000000000000000000000..ba1c06d48aa1125113f7a864ec26d5c9368a91f5 --- /dev/null +++ b/fastchat/model/apply_delta.py @@ -0,0 +1,165 @@ +""" +Apply the delta weights on top of a base model. + +Usage: +python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta-v1.1 +""" +import argparse +import gc +import glob +import json +import os +import shutil +import tempfile + +from huggingface_hub import snapshot_download +import torch +from torch import nn +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig + + +GB = 1 << 30 + + +def split_files(model_path, tmp_path, split_size): + if not os.path.exists(model_path): + model_path = snapshot_download(repo_id=model_path) + if not os.path.exists(tmp_path): + os.makedirs(tmp_path) + + file_pattern = os.path.join(model_path, "pytorch_model-*.bin") + files = glob.glob(file_pattern) + + part = 0 + try: + for file_path in tqdm(files): + state_dict = torch.load(file_path) + new_state_dict = {} + + current_size = 0 + for name, param in state_dict.items(): + param_size = param.numel() * param.element_size() + + if current_size + param_size > split_size: + new_file_name = f"pytorch_model-{part}.bin" + new_file_path = os.path.join(tmp_path, new_file_name) + torch.save(new_state_dict, new_file_path) + current_size = 0 + new_state_dict = None + gc.collect() + new_state_dict = {} + part += 1 + + new_state_dict[name] = param + current_size += param_size + + new_file_name = f"pytorch_model-{part}.bin" + new_file_path = os.path.join(tmp_path, new_file_name) + torch.save(new_state_dict, new_file_path) + new_state_dict = None + gc.collect() + new_state_dict = {} + part += 1 + except Exception as e: + print(f"An error occurred during split_files: {e}") + shutil.rmtree(tmp_path) + raise + + +def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path): + delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False) + delta_config = AutoConfig.from_pretrained(delta_path) + + if os.path.exists(target_model_path): + shutil.rmtree(target_model_path) + os.makedirs(target_model_path) + + split_size = 4 * GB + + with tempfile.TemporaryDirectory() as tmp_base_path, tempfile.TemporaryDirectory() as tmp_delta_path: + print(f"Split files for the base model to {tmp_base_path}") + split_files(base_model_path, tmp_base_path, split_size) + print(f"Split files for the delta weights to {tmp_delta_path}") + split_files(delta_path, tmp_delta_path, split_size) + + base_pattern = os.path.join(tmp_base_path, "pytorch_model-*.bin") + base_files = glob.glob(base_pattern) + delta_pattern = os.path.join(tmp_delta_path, "pytorch_model-*.bin") + delta_files = glob.glob(delta_pattern) + delta_state_dict = 
torch.load(delta_files[0]) + + print("Applying the delta") + weight_map = {} + total_size = 0 + + for i, base_file in tqdm(enumerate(base_files)): + state_dict = torch.load(base_file) + file_name = f"pytorch_model-{i}.bin" + for name, param in state_dict.items(): + if name not in delta_state_dict: + for delta_file in delta_files: + delta_state_dict = torch.load(delta_file) + gc.collect() + if name in delta_state_dict: + break + + state_dict[name] += delta_state_dict[name] + weight_map[name] = file_name + total_size += param.numel() * param.element_size() + gc.collect() + torch.save(state_dict, os.path.join(target_model_path, file_name)) + + with open( + os.path.join(target_model_path, "pytorch_model.bin.index.json"), "w" + ) as f: + json.dump( + {"weight_map": weight_map, "metadata": {"total_size": total_size}}, f + ) + + print(f"Saving the target model to {target_model_path}") + delta_tokenizer.save_pretrained(target_model_path) + delta_config.save_pretrained(target_model_path) + + +def apply_delta(base_model_path, target_model_path, delta_path): + print(f"Loading the delta weights from {delta_path}") + delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False) + delta = AutoModelForCausalLM.from_pretrained( + delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + + print(f"Loading the base model from {base_model_path}") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + + print("Applying the delta") + for name, param in tqdm(base.state_dict().items(), desc="Applying delta"): + assert name in delta.state_dict() + param.data += delta.state_dict()[name] + + print(f"Saving the target model to {target_model_path}") + base.save_pretrained(target_model_path) + delta_tokenizer.save_pretrained(target_model_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--delta-path", type=str, required=True) + parser.add_argument( + "--low-cpu-mem", + action="store_true", + help="Lower the cpu memory usage. This will split large files and use " + "disk as swap to reduce the memory usage below 10GB.", + ) + args = parser.parse_args() + + if args.low_cpu_mem: + apply_delta_low_cpu_mem( + args.base_model_path, args.target_model_path, args.delta_path + ) + else: + apply_delta(args.base_model_path, args.target_model_path, args.delta_path) diff --git a/fastchat/model/apply_lora.py b/fastchat/model/apply_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..01263dcc71535e275c7509af96d10eac3b79926b --- /dev/null +++ b/fastchat/model/apply_lora.py @@ -0,0 +1,48 @@ +""" +Apply the LoRA weights on top of a base model. 
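# Hedged sketch (not part of the patch): the delta checkpoints consumed above are a plain
# element-wise difference, so apply_delta() is the inverse of fastchat.model.make_delta
# (delta = target - base, target = base + delta).
import torch

_base = {"w": torch.tensor([1.0, 2.0])}
_target = {"w": torch.tensor([1.5, -0.5])}
_delta = {k: _target[k] - _base[k] for k in _base}       # what make_delta stores
_restored = {k: _base[k] + _delta[k] for k in _base}     # what apply_delta rebuilds
assert torch.equal(_restored["w"], _target["w"])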
+ +Usage: +python3 -m fastchat.model.apply_lora --base ~/model_weights/llama-7b --target ~/model_weights/baize-7b --lora project-baize/baize-lora-7B + +Dependency: +pip3 install git+https://github.com/huggingface/peft.git@2822398fbe896f25d4dac5e468624dc5fd65a51b +""" +import argparse + +import torch +from peft import PeftModel +from transformers import AutoTokenizer, AutoModelForCausalLM + + +def apply_lora(base_model_path, target_model_path, lora_path): + print(f"Loading the base model from {base_model_path}") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + base_tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast=False) + + print(f"Loading the LoRA adapter from {lora_path}") + + lora_model = PeftModel.from_pretrained( + base, + lora_path, + # torch_dtype=torch.float16 + ) + + print("Applying the LoRA") + model = lora_model.merge_and_unload() + + print(f"Saving the target model to {target_model_path}") + model.save_pretrained(target_model_path) + base_tokenizer.save_pretrained(target_model_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--lora-path", type=str, required=True) + + args = parser.parse_args() + + apply_lora(args.base_model_path, args.target_model_path, args.lora_path) diff --git a/fastchat/model/compression.py b/fastchat/model/compression.py new file mode 100644 index 0000000000000000000000000000000000000000..7329cfe0c5771c4b71d37e3c6b1a31aa95e79c66 --- /dev/null +++ b/fastchat/model/compression.py @@ -0,0 +1,312 @@ +import dataclasses +import gc +import glob +import os + +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from huggingface_hub import snapshot_download +import torch +from torch import Tensor +from torch.nn import functional as F +import torch.nn as nn +from tqdm import tqdm +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + AutoModel, + AutoModelForSeq2SeqLM, +) + + +@dataclasses.dataclass +class CompressionConfig: + """Group-wise quantization.""" + + num_bits: int + group_size: int + group_dim: int + symmetric: bool + enabled: bool = True + + +default_compression_config = CompressionConfig( + num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True +) + + +class CLinear(nn.Module): + """Compressed Linear Layer.""" + + def __init__(self, weight=None, bias=None, device=None): + super().__init__() + if weight is None: + self.weight = None + elif isinstance(weight, Tensor): + self.weight = compress(weight.data.to(device), default_compression_config) + else: + self.weight = weight + self.bias = bias + + def forward(self, input: Tensor) -> Tensor: + weight = decompress(self.weight, default_compression_config) + if self.bias is None: + return F.linear(input.to(weight.dtype), weight) + return F.linear(input.to(weight.dtype), weight, self.bias.to(weight.dtype)) + + +def compress_module(module, target_device): + for attr_str in dir(module): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + setattr( + module, + attr_str, + CLinear(target_attr.weight, target_attr.bias, target_device), + ) + for name, child in module.named_children(): + compress_module(child, target_device) + + +def get_compressed_list(module, prefix=""): + compressed_list = [] + for attr_str in 
dir(module): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + full_name = ( + f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight" + ) + compressed_list.append(full_name) + for name, child in module.named_children(): + child_prefix = f"{prefix}.{name}" if prefix else name + for each in get_compressed_list(child, child_prefix): + compressed_list.append(each) + return compressed_list + + +def apply_compressed_weight(module, compressed_state_dict, target_device, prefix=""): + for attr_str in dir(module): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + full_name = ( + f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight" + ) + setattr( + module, + attr_str, + CLinear( + compressed_state_dict[full_name], target_attr.bias, target_device + ), + ) + for name, child in module.named_children(): + child_prefix = f"{prefix}.{name}" if prefix else name + apply_compressed_weight( + child, compressed_state_dict, target_device, child_prefix + ) + + +def load_compress_model(model_path, device, torch_dtype, use_fast, revision="main"): + # partially load model + # `use_fast=True`` is not supported for some models. + try: + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=use_fast, revision=revision, trust_remote_code=True + ) + except TypeError: + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=~use_fast, revision=revision, trust_remote_code=True + ) + with init_empty_weights(): + # `trust_remote_code` should be set as `True` for both AutoConfig and AutoModel + config = AutoConfig.from_pretrained( + model_path, + low_cpu_mem_usage=True, + torch_dtype=torch_dtype, + trust_remote_code=True, + revision=revision, + ) + # some models are loaded by AutoModel but not AutoModelForCausalLM, + # such as chatglm, chatglm2 + try: + # google/flan-* models are based on an AutoModelForSeq2SeqLM. + if "T5Config" in str(type(config)): + model = AutoModelForSeq2SeqLM.from_config( + config, trust_remote_code=True + ) + else: + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + except NameError: + model = AutoModel.from_config(config, trust_remote_code=True) + linear_weights = get_compressed_list(model) + if os.path.exists(model_path): + # `model_path` is a local folder + base_pattern = os.path.join(model_path, "pytorch_model*.bin") + else: + # `model_path` is a cached Hugging Face repo + # We don't necessarily need to download the model' repo again if there is a cache. + # So check the default huggingface cache first. + model_path_temp = os.path.join( + os.path.expanduser("~"), + ".cache/huggingface/hub", + "models--" + model_path.replace("/", "--"), + "snapshots/", + ) + downloaded = False + if os.path.exists(model_path_temp): + temp_last_dir = os.listdir(model_path_temp)[-1] + model_path_temp = os.path.join(model_path_temp, temp_last_dir) + base_pattern = os.path.join(model_path_temp, "pytorch_model*.bin") + files = glob.glob(base_pattern) + if len(files) > 0: + downloaded = True + + if downloaded: + model_path = model_path_temp + else: + model_path = snapshot_download(model_path, revision=revision) + base_pattern = os.path.join(model_path, "pytorch_model*.bin") + + files = glob.glob(base_pattern) + use_safetensors = False + if len(files) == 0: + base_pattern = os.path.join(model_path, "*.safetensors") + files = glob.glob(base_pattern) + use_safetensors = True + if len(files) == 0: + raise ValueError( + f"Cannot find any model weight files. 
" + f"Please check your (cached) weight path: {model_path}" + ) + + compressed_state_dict = {} + if use_safetensors: + from safetensors.torch import load_file + for filename in tqdm(files): + if use_safetensors: + tmp_state_dict = load_file(filename) + else: + tmp_state_dict = torch.load( + filename, map_location=lambda storage, loc: storage + ) + for name in tmp_state_dict: + if name in linear_weights: + tensor = tmp_state_dict[name].to(device, dtype=torch_dtype) + compressed_state_dict[name] = compress( + tensor, default_compression_config + ) + else: + compressed_state_dict[name] = tmp_state_dict[name].to( + device, dtype=torch_dtype + ) + tmp_state_dict[name] = None + tensor = None + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() + + for name in model.state_dict(): + if name not in linear_weights: + set_module_tensor_to_device( + model, name, device, value=compressed_state_dict[name] + ) + apply_compressed_weight(model, compressed_state_dict, device) + + if torch_dtype == torch.float16: + model.half() + model.to(device) + model.eval() + + return model, tokenizer + + +def compress(tensor, config): + """Simulate group-wise quantization.""" + if not config.enabled: + return tensor + + group_size, num_bits, group_dim, symmetric = ( + config.group_size, + config.num_bits, + config.group_dim, + config.symmetric, + ) + assert num_bits <= 8 + + original_shape = tensor.shape + num_groups = (original_shape[group_dim] + group_size - 1) // group_size + new_shape = ( + original_shape[:group_dim] + + (num_groups, group_size) + + original_shape[group_dim + 1 :] + ) + + # Pad + pad_len = (group_size - original_shape[group_dim] % group_size) % group_size + if pad_len != 0: + pad_shape = ( + original_shape[:group_dim] + (pad_len,) + original_shape[group_dim + 1 :] + ) + tensor = torch.cat( + [tensor, torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)], + dim=group_dim, + ) + data = tensor.view(new_shape) + + # Quantize + if symmetric: + B = 2 ** (num_bits - 1) - 1 + scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0] + data = data * scale + data = data.clamp_(-B, B).round_().to(torch.int8) + return data, scale, original_shape + else: + B = 2**num_bits - 1 + mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0] + mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0] + + scale = B / (mx - mn) + data = data - mn + data.mul_(scale) + + data = data.clamp_(0, B).round_().to(torch.uint8) + return data, mn, scale, original_shape + + +def decompress(packed_data, config): + """Simulate group-wise dequantization.""" + if not config.enabled: + return packed_data + + group_size, num_bits, group_dim, symmetric = ( + config.group_size, + config.num_bits, + config.group_dim, + config.symmetric, + ) + + # Dequantize + if symmetric: + data, scale, original_shape = packed_data + data = data / scale + else: + data, mn, scale, original_shape = packed_data + data = data / scale + data.add_(mn) + + # Unpad + pad_len = (group_size - original_shape[group_dim] % group_size) % group_size + if pad_len: + padded_original_shape = ( + original_shape[:group_dim] + + (original_shape[group_dim] + pad_len,) + + original_shape[group_dim + 1 :] + ) + data = data.reshape(padded_original_shape) + indices = [slice(0, x) for x in original_shape] + return data[indices].contiguous() + else: + return data.view(original_shape) diff --git a/fastchat/model/convert_fp16.py b/fastchat/model/convert_fp16.py new file mode 100644 
index 0000000000000000000000000000000000000000..efc40aa83bf3a85129a668387df86a41d925f13d --- /dev/null +++ b/fastchat/model/convert_fp16.py @@ -0,0 +1,26 @@ +""" +Usage: +python3 -m fastchat.model.convert_fp16 --in in-folder --out out-folder +""" +import argparse + +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch + + +def convert_fp16(in_checkpoint, out_checkpoint): + tokenizer = AutoTokenizer.from_pretrained(in_checkpoint, use_fast=False) + model = AutoModelForCausalLM.from_pretrained( + in_checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + model.save_pretrained(out_checkpoint) + tokenizer.save_pretrained(out_checkpoint) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-checkpoint", type=str, help="Path to the model") + parser.add_argument("--out-checkpoint", type=str, help="Path to the output model") + args = parser.parse_args() + + convert_fp16(args.in_checkpoint, args.out_checkpoint) diff --git a/fastchat/model/llama_condense_monkey_patch.py b/fastchat/model/llama_condense_monkey_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..cb45a8bb6addf8a8506c847060e23dc65ae27995 --- /dev/null +++ b/fastchat/model/llama_condense_monkey_patch.py @@ -0,0 +1,71 @@ +# Code adapted from https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test/blob/main/llama_rope_scaled_monkey_patch.py + +from functools import partial + +import torch +import transformers +import transformers.models.llama.modeling_llama + + +class CondenseRotaryEmbedding(torch.nn.Module): + def __init__( + self, dim, ratio, max_position_embeddings=2048, base=10000, device=None + ): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. + self.ratio = ratio + max_position_embeddings *= ratio + self.max_seq_len_cached = max_position_embeddings + # print(f"Monkey Patching condense ratio {ratio}") + t = ( + torch.arange( + self.max_seq_len_cached, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype, + ) + / ratio + ) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + dtype = torch.get_default_dtype() + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False + ) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
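# Hedged check (not part of the patch): CondenseRotaryEmbedding above simply evaluates
# the rotary angles at t / ratio, so position 4*k under ratio=4 reuses the angle the
# unpatched table assigns to position k. That is how
# replace_llama_with_condense(rope_scaling["factor"]) stretches a 2048-token schedule
# to factor * 2048 positions.
import torch
from fastchat.model.llama_condense_monkey_patch import CondenseRotaryEmbedding

orig_emb = CondenseRotaryEmbedding(dim=128, ratio=1, max_position_embeddings=2048)
cond_emb = CondenseRotaryEmbedding(dim=128, ratio=4, max_position_embeddings=2048)
assert cond_emb.cos_cached.shape[2] == 4 * 2048
assert torch.allclose(cond_emb.cos_cached[0, 0, 4 * 2047], orig_emb.cos_cached[0, 0, 2047])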
+ if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = ( + torch.arange( + self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype + ) + / self.ratio + ) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False + ) + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +def replace_llama_with_condense(ratio): + transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = partial( + CondenseRotaryEmbedding, ratio=ratio + ) diff --git a/fastchat/model/make_delta.py b/fastchat/model/make_delta.py new file mode 100644 index 0000000000000000000000000000000000000000..480ba8f1a2cb067d69df174ee7d00e5072ee5164 --- /dev/null +++ b/fastchat/model/make_delta.py @@ -0,0 +1,48 @@ +""" +Make the delta weights by subtracting base weights. + +Usage: +python3 -m fastchat.model.make_delta --base ~/model_weights/llama-13b --target ~/model_weights/vicuna-13b --delta ~/model_weights/vicuna-13b-delta --hub-repo-id lmsys/vicuna-13b-delta-v1.1 +""" +import argparse + +import torch +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM + + +def make_delta(base_model_path, target_model_path, delta_path): + print(f"Loading the base model from {base_model_path}") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + + print(f"Loading the target model from {target_model_path}") + target = AutoModelForCausalLM.from_pretrained( + target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + target_tokenizer = AutoTokenizer.from_pretrained(target_model_path, use_fast=False) + + print("Calculating the delta") + for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): + assert name in base.state_dict() + param.data -= base.state_dict()[name] + + print(f"Saving the delta to {delta_path}") + if args.hub_repo_id: + kwargs = {"push_to_hub": True, "repo_id": args.hub_repo_id} + else: + kwargs = {} + target.save_pretrained(delta_path, **kwargs) + target_tokenizer.save_pretrained(delta_path, **kwargs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--delta-path", type=str, required=True) + parser.add_argument("--hub-repo-id", type=str) + args = parser.parse_args() + + make_delta(args.base_model_path, args.target_model_path, args.delta_path) diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..135c108854a7450d98700ac18dc338e22f13f85c --- /dev/null +++ b/fastchat/model/model_adapter.py @@ -0,0 +1,2550 @@ +"""Model adapter registration.""" + +import math +import os +import re +import sys +from typing import Dict, List, Optional +import warnings + +if sys.version_info >= (3, 9): + from functools import cache +else: + from functools import lru_cache as cache + +import psutil +import torch +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForCausalLM, + 
AutoModelForSeq2SeqLM, + AutoTokenizer, + LlamaTokenizer, + LlamaForCausalLM, + T5Tokenizer, +) + +from fastchat.constants import CPU_ISA +from fastchat.conversation import Conversation, get_conv_template +from fastchat.model.compression import load_compress_model +from fastchat.model.llama_condense_monkey_patch import replace_llama_with_condense +from fastchat.model.model_chatglm import generate_stream_chatglm +from fastchat.model.model_codet5p import generate_stream_codet5p +from fastchat.model.model_falcon import generate_stream_falcon +from fastchat.model.model_yuan2 import generate_stream_yuan2 +from fastchat.model.model_exllama import generate_stream_exllama +from fastchat.model.model_xfastertransformer import generate_stream_xft +from fastchat.model.model_cllm import generate_stream_cllm + +from fastchat.model.monkey_patch_non_inplace import ( + replace_llama_attn_with_non_inplace_operations, +) +from fastchat.modules.awq import AWQConfig, load_awq_quantized +from fastchat.modules.exllama import ExllamaConfig, load_exllama_model +from fastchat.modules.xfastertransformer import load_xft_model, XftConfig +from fastchat.modules.gptq import GptqConfig, load_gptq_quantized +from fastchat.utils import get_gpu_memory + +# Check an environment variable to check if we should be sharing Peft model +# weights. When false we treat all Peft models as separate. +peft_share_base_weights = ( + os.environ.get("PEFT_SHARE_BASE_WEIGHTS", "false").lower() == "true" +) + +ANTHROPIC_MODEL_LIST = ( + "claude-1", + "claude-2", + "claude-2.0", + "claude-2.1", + "claude-3-haiku-20240307", + "claude-3-haiku-20240307-vertex", + "claude-3-sonnet-20240229", + "claude-3-sonnet-20240229-vertex", + "claude-3-5-sonnet-20240620", + "claude-3-opus-20240229", + "claude-instant-1", + "claude-instant-1.2", +) + +OPENAI_MODEL_LIST = ( + "gpt-3.5-turbo", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-turbo", + "gpt-4-1106-preview", + "gpt-4-0125-preview", + "gpt-4-turbo-browsing", + "gpt-4-turbo-2024-04-09", + "gpt2-chatbot", + "im-also-a-good-gpt2-chatbot", + "im-a-good-gpt2-chatbot", + "gpt-4o-2024-05-13", +) + + +class BaseModelAdapter: + """The base and the default model adapter.""" + + use_fast_tokenizer = True + + def match(self, model_path: str): + return True + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + try: + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=self.use_fast_tokenizer, + revision=revision, + trust_remote_code=True, + ) + except TypeError: + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=False, revision=revision, trust_remote_code=True + ) + try: + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + except NameError: + model = AutoModel.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def load_compress_model(self, model_path, device, torch_dtype, revision="main"): + return load_compress_model( + model_path, + device, + torch_dtype, + use_fast=self.use_fast_tokenizer, + revision=revision, + ) + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("one_shot") + + +# A global registry for all model adapters +# TODO (lmzheng): make it a priority 
queue. +model_adapters: List[BaseModelAdapter] = [] + + +def register_model_adapter(cls): + """Register a model adapter.""" + model_adapters.append(cls()) + + +@cache +def get_model_adapter(model_path: str) -> BaseModelAdapter: + """Get a model adapter for a model_path.""" + model_path_basename = os.path.basename(os.path.normpath(model_path)) + + # Try the basename of model_path at first + for adapter in model_adapters: + if adapter.match(model_path_basename) and type(adapter) != BaseModelAdapter: + return adapter + + # Then try the full path + for adapter in model_adapters: + if adapter.match(model_path): + return adapter + + raise ValueError(f"No valid model adapter for {model_path}") + + +def raise_warning_for_incompatible_cpu_offloading_configuration( + device: str, load_8bit: bool, cpu_offloading: bool +): + if cpu_offloading: + if not load_8bit: + warnings.warn( + "The cpu-offloading feature can only be used while also using 8-bit-quantization.\n" + "Use '--load-8bit' to enable 8-bit-quantization\n" + "Continuing without cpu-offloading enabled\n" + ) + return False + if not "linux" in sys.platform: + warnings.warn( + "CPU-offloading is only supported on linux-systems due to the limited compatability with the bitsandbytes-package\n" + "Continuing without cpu-offloading enabled\n" + ) + return False + if device != "cuda": + warnings.warn( + "CPU-offloading is only enabled when using CUDA-devices\n" + "Continuing without cpu-offloading enabled\n" + ) + return False + return cpu_offloading + + +def load_model( + model_path: str, + device: str = "cuda", + num_gpus: int = 1, + max_gpu_memory: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + load_8bit: bool = False, + cpu_offloading: bool = False, + gptq_config: Optional[GptqConfig] = None, + awq_config: Optional[AWQConfig] = None, + exllama_config: Optional[ExllamaConfig] = None, + xft_config: Optional[XftConfig] = None, + revision: str = "main", + debug: bool = False, +): + """Load a model from Hugging Face.""" + import accelerate + + # get model adapter + adapter = get_model_adapter(model_path) + + # Handle device mapping + cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration( + device, load_8bit, cpu_offloading + ) + if device == "cpu": + kwargs = {"torch_dtype": torch.float32} + if CPU_ISA in ["avx512_bf16", "amx"]: + try: + import intel_extension_for_pytorch as ipex + + kwargs = {"torch_dtype": torch.bfloat16} + except ImportError: + warnings.warn( + "Intel Extension for PyTorch is not installed, it can be installed to accelerate cpu inference" + ) + elif device == "cuda": + kwargs = {"torch_dtype": torch.float16} + if num_gpus != 1: + kwargs["device_map"] = "auto" + if max_gpu_memory is None: + kwargs[ + "device_map" + ] = "sequential" # This is important for not the same VRAM sizes + available_gpu_memory = get_gpu_memory(num_gpus) + kwargs["max_memory"] = { + i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" + for i in range(num_gpus) + } + else: + kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)} + elif device == "mps": + kwargs = {"torch_dtype": torch.float16} + import transformers + + version = tuple(int(v) for v in transformers.__version__.split(".")) + if version < (4, 35, 0): + # NOTE: Recent transformers library seems to fix the mps issue, also + # it has made some changes causing compatibility issues with our + # original patch. So we only apply the patch for older versions. + + # Avoid bugs in mps backend by not using in-place operations. 
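# Hedged illustration (not part of the patch): a minimal third-party adapter only has to
# implement match() plus whatever it wants to override, then register itself with the
# registry defined above. The class name and the "my-llama" marker below are made up.
class MyLlamaAdapter(BaseModelAdapter):
    """Hypothetical adapter for checkpoints whose path contains 'my-llama'."""

    def match(self, model_path: str):
        return "my-llama" in model_path.lower()

    def get_default_conv_template(self, model_path: str) -> Conversation:
        return get_conv_template("vicuna_v1.1")


register_model_adapter(MyLlamaAdapter)
# get_model_adapter("org/my-llama-7b") would now resolve to this adapter via the
# basename-matching loop above, while unmatched paths still fall back to BaseModelAdapter.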
+ replace_llama_attn_with_non_inplace_operations() + elif device == "xpu": + kwargs = {"torch_dtype": torch.bfloat16} + # Try to load ipex, while it looks unused, it links into torch for xpu support + try: + import intel_extension_for_pytorch as ipex + except ImportError: + warnings.warn( + "Intel Extension for PyTorch is not installed, but is required for xpu inference." + ) + elif device == "npu": + kwargs = {"torch_dtype": torch.float16} + # Try to load ipex, while it looks unused, it links into torch for xpu support + try: + import torch_npu + except ImportError: + warnings.warn("Ascend Extension for PyTorch is not installed.") + else: + raise ValueError(f"Invalid device: {device}") + + if cpu_offloading: + # raises an error on incompatible platforms + from transformers import BitsAndBytesConfig + + if "max_memory" in kwargs: + kwargs["max_memory"]["cpu"] = ( + str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib" + ) + kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_8bit_fp32_cpu_offload=cpu_offloading + ) + kwargs["load_in_8bit"] = load_8bit + elif load_8bit: + if num_gpus != 1: + warnings.warn( + "8-bit quantization is not supported for multi-gpu inference." + ) + else: + model, tokenizer = adapter.load_compress_model( + model_path=model_path, + device=device, + torch_dtype=kwargs["torch_dtype"], + revision=revision, + ) + if debug: + print(model) + return model, tokenizer + elif awq_config and awq_config.wbits < 16: + assert ( + awq_config.wbits == 4 + ), "Currently we only support 4-bit inference for AWQ." + model, tokenizer = load_awq_quantized(model_path, awq_config, device) + if num_gpus != 1: + device_map = accelerate.infer_auto_device_map( + model, + max_memory=kwargs["max_memory"], + no_split_module_classes=[ + "OPTDecoderLayer", + "LlamaDecoderLayer", + "BloomBlock", + "MPTBlock", + "DecoderLayer", + ], + ) + model = accelerate.dispatch_model( + model, device_map=device_map, offload_buffers=True + ) + else: + model.to(device) + return model, tokenizer + elif gptq_config and gptq_config.wbits < 16: + model, tokenizer = load_gptq_quantized(model_path, gptq_config) + if num_gpus != 1: + device_map = accelerate.infer_auto_device_map( + model, + max_memory=kwargs["max_memory"], + no_split_module_classes=["LlamaDecoderLayer"], + ) + model = accelerate.dispatch_model( + model, device_map=device_map, offload_buffers=True + ) + else: + model.to(device) + return model, tokenizer + elif exllama_config: + model, tokenizer = load_exllama_model(model_path, exllama_config) + return model, tokenizer + elif xft_config: + model, tokenizer = load_xft_model(model_path, xft_config) + return model, tokenizer + kwargs["revision"] = revision + + if dtype is not None: # Overwrite dtype if it is provided in the arguments. + kwargs["torch_dtype"] = dtype + + if os.environ.get("FASTCHAT_USE_MODELSCOPE", "False").lower() == "true": + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. 
+ try: + from modelscope.hub.snapshot_download import snapshot_download + + if not os.path.exists(model_path): + model_path = snapshot_download(model_id=model_path, revision=revision) + except ImportError as e: + warnings.warn( + "Use model from www.modelscope.cn need pip install modelscope" + ) + raise e + + # Load model + model, tokenizer = adapter.load_model(model_path, kwargs) + + if ( + device == "cpu" + and kwargs["torch_dtype"] is torch.bfloat16 + and CPU_ISA is not None + ): + model = ipex.optimize(model, dtype=kwargs["torch_dtype"]) + + if (device == "cuda" and num_gpus == 1 and not cpu_offloading) or device in ( + "mps", + "xpu", + "npu", + ): + model.to(device) + + if device == "xpu": + model = torch.xpu.optimize(model, dtype=kwargs["torch_dtype"], inplace=True) + + if debug: + print(model) + + return model, tokenizer + + +def get_conversation_template(model_path: str) -> Conversation: + """Get the default conversation template.""" + adapter = get_model_adapter(model_path) + return adapter.get_default_conv_template(model_path) + + +def get_generate_stream_function(model: torch.nn.Module, model_path: str): + """Get the generate_stream function for inference.""" + from fastchat.serve.inference import generate_stream + + model_type = str(type(model)).lower() + is_peft = "peft" in model_type + is_chatglm = "chatglm" in model_type + is_falcon = "rwforcausallm" in model_type + is_codet5p = "codet5p" in model_type + is_exllama = "exllama" in model_type + is_xft = "xft" in model_type + is_yuan = "yuan" in model_type + is_cllm = "consistency-llm" in model_path.lower() + + if is_chatglm: + return generate_stream_chatglm + elif is_falcon: + return generate_stream_falcon + elif is_codet5p: + return generate_stream_codet5p + elif is_exllama: + return generate_stream_exllama + elif is_xft: + return generate_stream_xft + elif is_yuan: + return generate_stream_yuan2 + elif is_cllm: + return generate_stream_cllm + + elif peft_share_base_weights and is_peft: + # Return a curried stream function that loads the right adapter + # according to the model_name available in this context. This ensures + # the right weights are available. 
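# Hedged usage sketch (not part of the patch): the common single-GPU call path through
# load_model() and get_conversation_template() defined in this module.
from fastchat.model.model_adapter import get_conversation_template, load_model

model, tokenizer = load_model(
    "lmsys/vicuna-7b-v1.5",   # local folder or Hugging Face repo ID
    device="cuda",
    num_gpus=1,
    load_8bit=False,
)
conv = get_conversation_template("lmsys/vicuna-7b-v1.5")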
+ @torch.inference_mode() + def generate_stream_peft( + model, + tokenizer, + params: Dict, + device: str, + context_len: int, + stream_interval: int = 2, + judge_sent_end: bool = False, + ): + model.set_adapter(model_path) + base_model_type = str(type(model.base_model.model)) + is_chatglm = "chatglm" in base_model_type + is_falcon = "rwforcausallm" in base_model_type + is_codet5p = "codet5p" in base_model_type + is_exllama = "exllama" in base_model_type + is_xft = "xft" in base_model_type + is_yuan = "yuan" in base_model_type + is_cllm = "consistency-llm" in model_path.lower() + + generate_stream_function = generate_stream + if is_chatglm: + generate_stream_function = generate_stream_chatglm + elif is_falcon: + generate_stream_function = generate_stream_falcon + elif is_codet5p: + generate_stream_function = generate_stream_codet5p + elif is_exllama: + generate_stream_function = generate_stream_exllama + elif is_xft: + generate_stream_function = generate_stream_xft + elif is_yuan: + generate_stream_function = generate_stream_yuan2 + elif is_cllm: + generate_stream_function = generate_stream_cllm + for x in generate_stream_function( + model, + tokenizer, + params, + device, + context_len, + stream_interval, + judge_sent_end, + ): + yield x + + return generate_stream_peft + else: + return generate_stream + + +def add_model_args(parser): + parser.add_argument( + "--model-path", + type=str, + default="lmsys/vicuna-7b-v1.5", + help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", + ) + parser.add_argument( + "--revision", + type=str, + default="main", + help="Hugging Face Hub model revision identifier", + ) + parser.add_argument( + "--device", + type=str, + choices=["cpu", "cuda", "mps", "xpu", "npu"], + default="cuda", + help="The device type", + ) + parser.add_argument( + "--gpus", + type=str, + default=None, + help="A single GPU like 1 or multiple GPUs like 0,2", + ) + parser.add_argument("--num-gpus", type=int, default=1) + parser.add_argument( + "--max-gpu-memory", + type=str, + help="The maximum memory per GPU for storing model weights. Use a string like '13Gib'", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["float32", "float16", "bfloat16"], + help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.", + default=None, + ) + parser.add_argument( + "--load-8bit", action="store_true", help="Use 8-bit quantization" + ) + parser.add_argument( + "--cpu-offloading", + action="store_true", + help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU", + ) + parser.add_argument( + "--gptq-ckpt", + type=str, + default=None, + help="Used for GPTQ. The path to the local GPTQ checkpoint.", + ) + parser.add_argument( + "--gptq-wbits", + type=int, + default=16, + choices=[2, 3, 4, 8, 16], + help="Used for GPTQ. #bits to use for quantization", + ) + parser.add_argument( + "--gptq-groupsize", + type=int, + default=-1, + help="Used for GPTQ. Groupsize to use for quantization; default uses full row.", + ) + parser.add_argument( + "--gptq-act-order", + action="store_true", + help="Used for GPTQ. Whether to apply the activation order GPTQ heuristic", + ) + parser.add_argument( + "--awq-ckpt", + type=str, + default=None, + help="Used for AWQ. Load quantized model. The path to the local AWQ checkpoint.", + ) + parser.add_argument( + "--awq-wbits", + type=int, + default=16, + choices=[4, 16], + help="Used for AWQ. 
#bits to use for AWQ quantization", + ) + parser.add_argument( + "--awq-groupsize", + type=int, + default=-1, + help="Used for AWQ. Groupsize to use for AWQ quantization; default uses full row.", + ) + parser.add_argument( + "--enable-exllama", + action="store_true", + help="Used for exllamabv2. Enable exllamaV2 inference framework.", + ) + parser.add_argument( + "--exllama-max-seq-len", + type=int, + default=4096, + help="Used for exllamabv2. Max sequence length to use for exllamav2 framework; default 4096 sequence length.", + ) + parser.add_argument( + "--exllama-gpu-split", + type=str, + default=None, + help="Used for exllamabv2. Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7", + ) + parser.add_argument( + "--exllama-cache-8bit", + action="store_true", + help="Used for exllamabv2. Use 8-bit cache to save VRAM.", + ) + parser.add_argument( + "--enable-xft", + action="store_true", + help="Used for xFasterTransformer Enable xFasterTransformer inference framework.", + ) + parser.add_argument( + "--xft-max-seq-len", + type=int, + default=4096, + help="Used for xFasterTransformer. Max sequence length to use for xFasterTransformer framework; default 4096 sequence length.", + ) + parser.add_argument( + "--xft-dtype", + type=str, + choices=["fp16", "bf16", "int8", "bf16_fp16", "bf16_int8"], + help="Override the default dtype. If not set, it will use bfloat16 for first token and float16 next tokens on CPU.", + default=None, + ) + + +def remove_parent_directory_name(model_path): + """Remove parent directory name.""" + if model_path[-1] == "/": + model_path = model_path[:-1] + return model_path.split("/")[-1] + + +peft_model_cache = {} + + +class PeftModelAdapter: + """Loads any "peft" model and it's base model.""" + + def match(self, model_path: str): + """Accepts any model path with "peft" in the name""" + if os.path.exists(os.path.join(model_path, "adapter_config.json")): + return True + return "peft" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + """Loads the base model then the (peft) adapter weights""" + from peft import PeftConfig, PeftModel + + config = PeftConfig.from_pretrained(model_path) + base_model_path = config.base_model_name_or_path + if "peft" in base_model_path: + raise ValueError( + f"PeftModelAdapter cannot load a base model with 'peft' in the name: {config.base_model_name_or_path}" + ) + + # Basic proof of concept for loading peft adapters that share the base + # weights. This is pretty messy because Peft re-writes the underlying + # base model and internally stores a map of adapter layers. + # So, to make this work we: + # 1. Cache the first peft model loaded for a given base models. + # 2. Call `load_model` for any follow on Peft models. + # 3. Make sure we load the adapters by the model_path. Why? This is + # what's accessible during inference time. + # 4. In get_generate_stream_function, make sure we load the right + # adapter before doing inference. This *should* be safe when calls + # are blocked the same semaphore. + if peft_share_base_weights: + if base_model_path in peft_model_cache: + model, tokenizer = peft_model_cache[base_model_path] + # Super important: make sure we use model_path as the + # `adapter_name`. + model.load_adapter(model_path, adapter_name=model_path) + else: + base_adapter = get_model_adapter(base_model_path) + base_model, tokenizer = base_adapter.load_model( + base_model_path, from_pretrained_kwargs + ) + # Super important: make sure we use model_path as the + # `adapter_name`. 
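# Hedged usage sketch (not part of the patch): add_model_args() defined above is the
# shared argparse helper; a front-end script would typically wire it up like this.
import argparse

from fastchat.model.model_adapter import add_model_args

parser = argparse.ArgumentParser()
add_model_args(parser)
args = parser.parse_args(["--model-path", "lmsys/vicuna-7b-v1.5", "--device", "cpu"])
print(args.model_path, args.device, args.load_8bit)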
+ model = PeftModel.from_pretrained( + base_model, model_path, adapter_name=model_path + ) + peft_model_cache[base_model_path] = (model, tokenizer) + return model, tokenizer + + # In the normal case, load up the base model weights again. + base_adapter = get_model_adapter(base_model_path) + base_model, tokenizer = base_adapter.load_model( + base_model_path, from_pretrained_kwargs + ) + model = PeftModel.from_pretrained(base_model, model_path) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + """Uses the conv template of the base model""" + from peft import PeftConfig, PeftModel + + config = PeftConfig.from_pretrained(model_path) + if "peft" in config.base_model_name_or_path: + raise ValueError( + f"PeftModelAdapter cannot load a base model with 'peft' in the name: {config.base_model_name_or_path}" + ) + base_model_path = config.base_model_name_or_path + base_adapter = get_model_adapter(base_model_path) + return base_adapter.get_default_conv_template(config.base_model_name_or_path) + + +class VicunaAdapter(BaseModelAdapter): + "Model adapter for Vicuna models (e.g., lmsys/vicuna-7b-v1.5)" "" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "vicuna" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=self.use_fast_tokenizer, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + self.raise_warning_for_old_weights(model) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "v0" in remove_parent_directory_name(model_path): + return get_conv_template("one_shot") + return get_conv_template("vicuna_v1.1") + + def raise_warning_for_old_weights(self, model): + if isinstance(model, LlamaForCausalLM) and model.model.vocab_size > 32000: + warnings.warn( + "\nYou are probably using the old Vicuna-v0 model, " + "which will generate unexpected results with the " + "current fastchat.\nYou can try one of the following methods:\n" + "1. Upgrade your weights to the new Vicuna-v1.3: https://github.com/lm-sys/FastChat#vicuna-weights.\n" + "2. Use the old conversation template by `python3 -m fastchat.serve.cli --model-path /path/to/vicuna-v0 --conv-template one_shot`\n" + "3. Downgrade fschat to fschat==0.1.10 (Not recommended).\n" + ) + + +class AiroborosAdapter(BaseModelAdapter): + """The model adapter for jondurbin/airoboros-*""" + + def match(self, model_path: str): + if re.search(r"airoboros|spicyboros", model_path, re.I): + return True + return False + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "-3." 
in model_path or "-3p" in model_path: + return get_conv_template("airoboros_v3") + if "spicyboros" in model_path or re.search(r"-(2\.[2-9]+)", model_path): + return get_conv_template("airoboros_v2") + return get_conv_template("airoboros_v1") + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + if "mpt" not in model_path.lower(): + return super().load_model(model_path, from_pretrained_kwargs) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + max_seq_len=8192, + **from_pretrained_kwargs, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, use_fast=True + ) + return model, tokenizer + + +class LongChatAdapter(BaseModelAdapter): + "Model adapter for LongChat models (e.g., lmsys/longchat-7b-16k)." + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "longchat" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + + # Apply monkey patch, TODO(Dacheng): Add flash attention support + config = AutoConfig.from_pretrained(model_path, revision=revision) + replace_llama_with_condense(config.rope_scaling["factor"]) + + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=self.use_fast_tokenizer, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("vicuna_v1.1") + + +class GoogleT5Adapter(BaseModelAdapter): + """The model adapter for google/Flan based models, such as Salesforce/codet5p-6b, lmsys/fastchat-t5-3b-v1.0, flan-t5-*, flan-ul2""" + + def match(self, model_path: str): + return any( + model_str in model_path.lower() + for model_str in ["flan-", "fastchat-t5", "codet5p"] + ) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = T5Tokenizer.from_pretrained(model_path, revision=revision) + model = AutoModelForSeq2SeqLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + +class KoalaAdapter(BaseModelAdapter): + """The model adapter for Koala""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "koala" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("koala_v1") + + +class AlpacaAdapter(BaseModelAdapter): + """The model adapter for Alpaca""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "alpaca" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("alpaca") + + +class ChatGLMAdapter(BaseModelAdapter): + """The model adapter for THUDM/chatglm-6b, THUDM/chatglm2-6b""" + + def match(self, model_path: str): + return "chatglm" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + if "chatglm3" in model_path.lower(): + tokenizer = AutoTokenizer.from_pretrained( + model_path, + encode_special_tokens=True, + trust_remote_code=True, + revision=revision, + ) + else: + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, 
revision=revision + ) + model = AutoModel.from_pretrained( + model_path, trust_remote_code=True, **from_pretrained_kwargs + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + model_path = model_path.lower() + if "chatglm2" in model_path.lower(): + return get_conv_template("chatglm2") + if "chatglm3" in model_path.lower(): + return get_conv_template("chatglm3") + return get_conv_template("chatglm") + + +class CodeGeexAdapter(BaseModelAdapter): + """The model adapter for THUDM/codegeex-6b, THUDM/codegeex2-6b""" + + def match(self, model_path: str): + return "codegeex" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + model = AutoModel.from_pretrained( + model_path, trust_remote_code=True, **from_pretrained_kwargs + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("codegeex") + + +class DollyV2Adapter(BaseModelAdapter): + """The model adapter for databricks/dolly-v2-12b""" + + def match(self, model_path: str): + return "dolly-v2" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + # 50277 means "### End" + tokenizer.eos_token_id = 50277 + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("dolly_v2") + + +class OasstPythiaAdapter(BaseModelAdapter): + """The model adapter for OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5""" + + def match(self, model_path: str): + model_path = model_path.lower() + return "oasst" in model_path and "pythia" in model_path + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("oasst_pythia") + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + +class OasstLLaMAAdapter(BaseModelAdapter): + """The model adapter for OpenAssistant/oasst-sft-7-llama-30b""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + model_path = model_path.lower() + if "openassistant-sft-7-llama-30b-hf" in model_path: + return True + return "oasst" in model_path and "pythia" not in model_path + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("oasst_llama") + + +class OpenChat35Adapter(BaseModelAdapter): + """The model adapter for OpenChat 3.5 (e.g. 
openchat/openchat_3.5)""" + + def match(self, model_path: str): + if "openchat" in model_path.lower() and "3.5" in model_path.lower(): + return True + elif "starling-lm" in model_path.lower(): + return True + return False + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("openchat_3.5") + + +class TenyxChatAdapter(BaseModelAdapter): + """The model adapter for TenyxChat (e.g. tenyx/TenyxChat-7B-v1)""" + + def match(self, model_path: str): + return "tenyxchat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("tenyxchat") + + +class PythiaAdapter(BaseModelAdapter): + """The model adapter for any EleutherAI/pythia model""" + + def match(self, model_path: str): + return "pythia" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + +class StableLMAdapter(BaseModelAdapter): + """The model adapter for StabilityAI/stablelm-tuned-alpha-7b""" + + def match(self, model_path: str): + return "stablelm" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("stablelm") + + +class MPTAdapter(BaseModelAdapter): + """The model adapter for MPT series (mosaicml/mpt-7b-chat, mosaicml/mpt-30b-chat)""" + + def match(self, model_path: str): + model_path = model_path.lower() + return "mpt" in model_path and not "airoboros" in model_path + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + max_seq_len=8192, + **from_pretrained_kwargs, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + model_path = model_path.lower() + if "mpt-7b-chat" in model_path: + return get_conv_template("mpt-7b-chat") + elif "mpt-30b-chat" in model_path: + return get_conv_template("mpt-30b-chat") + elif "mpt-30b-instruct" in model_path: + return get_conv_template("mpt-30b-instruct") + else: + print( + "Warning: Loading base MPT model with `zero_shot` conversation configuration. " + "If this is not desired, inspect model configurations and names." 
+ ) + return get_conv_template("zero_shot") + + +class BaizeAdapter(BaseModelAdapter): + """The model adapter for project-baize/baize-v2-7b""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "baize" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("baize") + + +class RwkvAdapter(BaseModelAdapter): + """The model adapter for BlinkDL/RWKV-4-Raven""" + + def match(self, model_path: str): + return "rwkv-4" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + from fastchat.model.rwkv_model import RwkvModel + + model = RwkvModel(model_path) + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + "EleutherAI/pythia-160m", revision=revision + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("rwkv") + + +class OpenBuddyAdapter(BaseModelAdapter): + """The model adapter for OpenBuddy/openbuddy-7b-v1.1-bf16-enc""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "openbuddy" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("openbuddy") + + +class PhoenixAdapter(BaseModelAdapter): + """The model adapter for FreedomIntelligence/phoenix-inst-chat-7b""" + + def match(self, model_path: str): + return "phoenix" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("phoenix") + + +class ReaLMAdapter(BaseModelAdapter): + """The model adapter for FreedomIntelligence/ReaLM-7b""" + + def match(self, model_path: str): + return "ReaLM" in model_path + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) + model = AutoModelForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("ReaLM-7b-v1") + + +class ChatGPTAdapter(BaseModelAdapter): + """The model adapter for ChatGPT""" + + def match(self, model_path: str): + return model_path in OPENAI_MODEL_LIST + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "browsing" in model_path: + return get_conv_template("api_based_default") + if "gpt-4-turbo-2024-04-09" in model_path: + return get_conv_template("gpt-4-turbo-2024-04-09") + if "gpt2-chatbot" in model_path: + return get_conv_template("gpt-4-turbo-2024-04-09") + if "gpt-4o" in model_path: + return get_conv_template("gpt-4-turbo-2024-04-09") + return get_conv_template("chatgpt") + + +class AzureOpenAIAdapter(BaseModelAdapter): + """The model adapter for Azure OpenAI""" + + def match(self, model_path: str): + return model_path in ("azure-gpt-35-turbo", "azure-gpt-4") + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("chatgpt") + + +class PplxAIAdapter(BaseModelAdapter): + """The model adapter for Perplexity AI""" + + def match(self, model_path: str): + return model_path in ( + "pplx-7b-online", + "pplx-70b-online", + ) + + def load_model(self, 
model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("pplxai") + + +class ClaudeAdapter(BaseModelAdapter): + """The model adapter for Claude""" + + def match(self, model_path: str): + return model_path in ANTHROPIC_MODEL_LIST + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "claude-3-haiku" in model_path: + return get_conv_template("claude-3-haiku-20240307") + if "claude-3-sonnet" in model_path: + return get_conv_template("claude-3-sonnet-20240229") + if "claude-3-5-sonnet" in model_path: + return get_conv_template("claude-3-5-sonnet-20240620") + if "claude-3-opus" in model_path: + return get_conv_template("claude-3-opus-20240229") + return get_conv_template("claude") + + +class BardAdapter(BaseModelAdapter): + """The model adapter for Bard""" + + def match(self, model_path: str): + return model_path == "bard" + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("bard") + + +class PaLM2Adapter(BaseModelAdapter): + """The model adapter for PaLM2""" + + def match(self, model_path: str): + return model_path == "palm-2" + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("bard") + + +class GeminiAdapter(BaseModelAdapter): + """The model adapter for Gemini""" + + def match(self, model_path: str): + return "gemini" in model_path.lower() or "bard" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("gemini") + + +class GeminiDevAdapter(BaseModelAdapter): + """The model adapter for Gemini 1.5 Pro""" + + def match(self, model_path: str): + return "gemini-1.5-pro" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("gemini-dev") + + +class BiLLaAdapter(BaseModelAdapter): + """The model adapter for Neutralzz/BiLLa-7B-SFT""" + + def match(self, model_path: str): + return "billa" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("billa") + + +class RedPajamaINCITEAdapter(BaseModelAdapter): + """The model adapter for togethercomputer/RedPajama-INCITE-7B-Chat""" + + def match(self, model_path: str): + return "redpajama-incite" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("redpajama-incite") + + +class H2OGPTAdapter(BaseModelAdapter): + """The model adapter for h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b""" + + 
use_fast_tokenizer = False + + def match(self, model_path: str): + return "h2ogpt" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("h2ogpt") + + +class RobinAdapter(BaseModelAdapter): + """The model adapter for LMFlow/Full-Robin-7b-v2""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "robin" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("Robin") + + +class SnoozyAdapter(BaseModelAdapter): + """The model adapter for nomic-ai/gpt4all-13b-snoozy""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + model_path = model_path.lower() + return "gpt4all" in model_path and "snoozy" in model_path + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("snoozy") + + +class WizardLMAdapter(BaseModelAdapter): + """The model adapter for WizardLM/WizardLM-13B-V1.0""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "wizardlm" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + model_path = model_path.lower() + if "13b" in model_path or "30b" in model_path or "70b" in model_path: + return get_conv_template("vicuna_v1.1") + else: + # TODO: use the recommended template for 7B + # (https://huggingface.co/WizardLM/WizardLM-13B-V1.0) + return get_conv_template("one_shot") + + +class ManticoreAdapter(BaseModelAdapter): + """The model adapter for openaccess-ai-collective/manticore-13b-chat-pyg""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "manticore" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("manticore") + + +class GuanacoAdapter(BaseModelAdapter): + """The model adapter for timdettmers/guanaco-33b-merged""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "guanaco" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=self.use_fast_tokenizer, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + # Fix a bug in tokenizer config + tokenizer.eos_token_id = model.config.eos_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("zero_shot") + + +class ChangGPTAdapter(BaseModelAdapter): + """The model adapter for lcw99/polyglot-ko-12.8b-chang-instruct-chat""" + + def match(self, model_path: str): + model_path = model_path.lower() + return "polyglot" in model_path and "chang" in model_path + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("polyglot_changgpt") + + +class CamelAdapter(BaseModelAdapter): + """The model adapter for camel-ai/CAMEL-13B-Combined-Data""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "camel" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("vicuna_v1.1") + + +class TuluAdapter(BaseModelAdapter): + """The model adapter for allenai/tulu-30b""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "tulu" in model_path.lower() + + 
def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("tulu")
+
+
+class FalconAdapter(BaseModelAdapter):
+    """The model adapter for tiiuae/falcon-40b"""
+
+    def match(self, model_path: str):
+        return "falcon" in model_path.lower() and "chat" not in model_path.lower()
+
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        revision = from_pretrained_kwargs.get("revision", "main")
+        # Using bf16 is strongly suggested, as recommended by the Falcon authors.
+        tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True,
+            **from_pretrained_kwargs,
+        )
+        # The Falcon tokenizer config and special-tokens config do not define a pad token.
+        # Set `pad_token_id` to 9, which corresponds to the special token '>>SUFFIX<<'.
+        tokenizer.pad_token_id = 9
+        return model, tokenizer
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("falcon")
+
+
+class FalconChatAdapter(BaseModelAdapter):
+    def match(self, model_path: str):
+        return "falcon" in model_path.lower() and "chat" in model_path.lower()
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("falcon-chat")
+
+
+class TigerBotAdapter(BaseModelAdapter):
+    """The model adapter for TigerResearch/tigerbot-7b-sft"""
+
+    def match(self, model_path: str):
+        return "tigerbot" in model_path.lower()
+
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        revision = from_pretrained_kwargs.get("revision", "main")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_path,
+            trust_remote_code=True,
+            revision=revision,
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            trust_remote_code=True,
+            low_cpu_mem_usage=True,
+            **from_pretrained_kwargs,
+        )
+        return model, tokenizer
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("tigerbot")
+
+
+class BaichuanAdapter(BaseModelAdapter):
+    """The model adapter for Baichuan models (e.g., baichuan-inc/Baichuan-7B)"""
+
+    def match(self, model_path: str):
+        return "baichuan" in model_path.lower()
+
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        revision = from_pretrained_kwargs.get("revision", "main")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_path, trust_remote_code=True, revision=revision
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            trust_remote_code=True,
+            low_cpu_mem_usage=True,
+            **from_pretrained_kwargs,
+        )
+        return model, tokenizer
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        # for Baichuan-13B-Chat
+        if "chat" in model_path.lower():
+            if "baichuan2" in model_path.lower():
+                return get_conv_template("baichuan2-chat")
+            return get_conv_template("baichuan-chat")
+        return get_conv_template("zero_shot")
+
+
+class XGenAdapter(BaseModelAdapter):
+    """The model adapter for Salesforce/xgen-7b"""
+
+    def match(self, model_path: str):
+        return "xgen" in model_path.lower()
+
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        revision = from_pretrained_kwargs.get("revision", "main")
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True,
+            **from_pretrained_kwargs,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_path, trust_remote_code=True, revision=revision
+        )
+        model.config.eos_token_id =
50256 + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("xgen") + + +class NousHermesAdapter(BaseModelAdapter): + """The model adapter for NousResearch/Nous-Hermes-13b""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "nous-hermes" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("alpaca") + + +class InternLMChatAdapter(BaseModelAdapter): + """The model adapter for internlm/internlm-chat-7b""" + + def match(self, model_path: str): + return "internlm" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + model = model.eval() + if "8k" in model_path.lower(): + model.config.max_sequence_length = 8192 + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("internlm-chat") + + +class StarChatAdapter(BaseModelAdapter): + """The model adapter for HuggingFaceH4/starchat-beta""" + + def match(self, model_path: str): + return "starchat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("starchat") + + +class MistralAdapter(BaseModelAdapter): + """The model adapter for Mistral AI models""" + + def match(self, model_path: str): + return "mistral" in model_path.lower() or "mixtral" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("mistral") + + +class Llama2Adapter(BaseModelAdapter): + """The model adapter for Llama-2 (e.g., meta-llama/Llama-2-7b-hf)""" + + def match(self, model_path: str): + return "llama-2" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("llama-2") + + +class Llama3Adapter(BaseModelAdapter): + """The model adapter for Llama-3 (e.g., meta-llama/Meta-Llama-3-8B-Instruct)""" + + def match(self, model_path: str): + return "llama-3" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("llama-3") + + +class CuteGPTAdapter(BaseModelAdapter): + """The model adapter for CuteGPT""" + + def match(self, model_path: str): + return "cutegpt" in model_path.lower() + + def load_model(self, 
model_path: str, from_pretrained_kwargs: dict): + tokenizer = LlamaTokenizer.from_pretrained(model_path) + model = AutoModelForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids("") + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.eos_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("cutegpt") + + +class OpenOrcaAdapter(BaseModelAdapter): + """Model adapter for Open-Orca models which may use different prompt templates + - (e.g. Open-Orca/OpenOrcaxOpenChat-Preview2-13B, Open-Orca/Mistral-7B-OpenOrca) + - `OpenOrcaxOpenChat-Preview2-13B` uses their "OpenChat Llama2 V1" prompt template. + - [Open-Orca/OpenOrcaxOpenChat-Preview2-13B #Prompt Template](https://huggingface.co/Open-Orca/OpenOrcaxOpenChat-Preview2-13B#prompt-template) + - `Mistral-7B-OpenOrca` uses the [OpenAI's Chat Markup Language (ChatML)](https://github.com/openai/openai-python/blob/main/chatml.md) + format, with <|im_start|> and <|im_end|> tokens added to support this. + - [Open-Orca/Mistral-7B-OpenOrca #Prompt Template](https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca#prompt-template) + """ + + use_fast_tokenizer = False + + def match(self, model_path: str): + return ( + "mistral-7b-openorca" in model_path.lower() + or "openorca" in model_path.lower() + ) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=self.use_fast_tokenizer, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ).eval() + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "mistral-7b-openorca" in model_path.lower(): + return get_conv_template("mistral-7b-openorca") + return get_conv_template("open-orca") + + +class DolphinAdapter(OpenOrcaAdapter): + """Model adapter for ehartford/dolphin-2.2.1-mistral-7b""" + + def match(self, model_path: str): + return "dolphin" in model_path.lower() and "mistral" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("dolphin-2.2.1-mistral-7b") + + +class Hermes2Adapter(BaseModelAdapter): + """Model adapter for teknium/OpenHermes-2.5-Mistral-7B and teknium/OpenHermes-2-Mistral-7B models""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return any( + model_str in model_path.lower() + for model_str in ["openhermes-2.5-mistral-7b", "openhermes-2-mistral-7b"] + ) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=self.use_fast_tokenizer, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ).eval() + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("OpenHermes-2.5-Mistral-7B") + + +class NousHermes2MixtralAdapter(BaseModelAdapter): + """Model adapter for NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO model""" + + def match(self, model_path: str): + return any( + model_str in model_path.lower() + for model_str in [ 
+ "nous-hermes-2-mixtral-8x7b-dpo", + "nous-hermes-2-mixtral-8x7b-sft", + ] + ) + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("Nous-Hermes-2-Mixtral-8x7B-DPO") + + +class WizardCoderAdapter(BaseModelAdapter): + """The model adapter for WizardCoder (e.g., WizardLM/WizardCoder-Python-34B-V1.0)""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "wizardcoder" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + # Same as Alpaca, see : + # https://github.com/nlpxucan/WizardLM/blob/main/WizardCoder/src/inference_wizardcoder.py#L60 + return get_conv_template("alpaca") + + +class QwenChatAdapter(BaseModelAdapter): + """The model adapter for Qwen/Qwen-7B-Chat + To run this model, you need to ensure additional flash attention installation: + ``` bash + git clone https://github.com/Dao-AILab/flash-attention + cd flash-attention && pip install . + pip install csrc/layer_norm + pip install csrc/rotary + ``` + + Since from 2.0, the following change happened + - `flash_attn_unpadded_func` -> `flash_attn_varlen_func` + - `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func` + - `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func` + You may need to revise the code in: https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/modeling_qwen.py#L69 + to from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func + """ + + def match(self, model_path: str): + return "qwen" in model_path.lower() + + def float_set(self, config, option): + config.bf16 = False + config.fp16 = False + config.fp32 = False + + if option == "bf16": + config.bf16 = True + elif option == "fp16": + config.fp16 = True + elif option == "fp32": + config.fp32 = True + else: + print("Invalid option. 
Please choose one from 'bf16', 'fp16' and 'fp32'.") + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + from transformers.generation import GenerationConfig + + revision = from_pretrained_kwargs.get("revision", "main") + config = AutoConfig.from_pretrained( + model_path, + trust_remote_code=True, + ) + # NOTE: if you use the old version of model file, please remove the comments below + # config.use_flash_attn = False + self.float_set(config, "fp16") + generation_config = GenerationConfig.from_pretrained( + model_path, trust_remote_code=True + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + config=config, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ).eval() + if hasattr(model.config, "use_dynamic_ntk") and model.config.use_dynamic_ntk: + model.config.max_sequence_length = 16384 + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + tokenizer.eos_token_id = config.eos_token_id + tokenizer.bos_token_id = config.bos_token_id + tokenizer.pad_token_id = generation_config.pad_token_id + model.config.eos_token_id = tokenizer.eos_token_id + model.config.bos_token_id = tokenizer.bos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("qwen-7b-chat") + + +class SmaugChatAdapter(BaseModelAdapter): + """The model adapter for abacusai/Smaug-2-72B.""" + + def match(self, model_path: str): + return "smaug" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("qwen-7b-chat") + + +class BGEAdapter(BaseModelAdapter): + """The model adapter for BGE (e.g., BAAI/bge-large-en-v1.5)""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "bge" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModel.from_pretrained( + model_path, + **from_pretrained_kwargs, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + if hasattr(model.config, "max_position_embeddings") and hasattr( + tokenizer, "model_max_length" + ): + model.config.max_sequence_length = min( + model.config.max_position_embeddings, tokenizer.model_max_length + ) + model.use_cls_pooling = True + model.eval() + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("one_shot") + + +class E5Adapter(BaseModelAdapter): + """The model adapter for E5 (e.g., intfloat/e5-large-v2)""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "e5-" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModel.from_pretrained( + model_path, + **from_pretrained_kwargs, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + if hasattr(model.config, "max_position_embeddings") and hasattr( + tokenizer, "model_max_length" + ): + model.config.max_sequence_length = min( + model.config.max_position_embeddings, tokenizer.model_max_length + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("one_shot") + + +class 
AquilaChatAdapter(BaseModelAdapter): + """The model adapter for BAAI/Aquila + + Now supports: + - BAAI/AquilaChat-7B + - BAAI/AquilaChat2-7B + - BAAI/AquilaChat2-34B + """ + + def match(self, model_path: str): + return "aquila" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + model = model.eval() + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + model_path = model_path.lower() + # See: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L347 + if "aquilachat2" in model_path: + if "16k" in model_path: + return get_conv_template("aquila") + elif "34b" in model_path: + return get_conv_template("aquila-legacy") + else: + return get_conv_template("aquila-v1") + else: + return get_conv_template("aquila-chat") + + +class Lamma2ChineseAdapter(BaseModelAdapter): + """The model adapter for FlagAlpha/LLama2-Chinese sft""" + + def match(self, model_path: str): + return "llama2-chinese" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + trust_remote_code=True, + revision=revision, + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("llama2-chinese") + + +class Lamma2ChineseAlpacaAdapter(BaseModelAdapter): + """The model adapter for ymcui/Chinese-LLaMA-Alpaca sft""" + + def match(self, model_path: str): + return "chinese-alpaca" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + trust_remote_code=True, + revision=revision, + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("chinese-alpaca2") + + +class VigogneAdapter(BaseModelAdapter): + """The model adapter for vigogne (e.g., bofenghuang/vigogne-2-7b-chat)""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return bool(re.search(r"vigogne|vigostral", model_path, re.I)) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=self.use_fast_tokenizer, + trust_remote_code=True, + revision=revision, + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ).eval() + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "chat" in model_path.lower(): + if "vigostral" in model_path.lower(): + return get_conv_template("vigogne_chat_v3") + 
return get_conv_template("vigogne_chat_v2") + return get_conv_template("vigogne_instruct") + + +class OpenLLaMaOpenInstructAdapter(BaseModelAdapter): + """The model adapter for OpenLLaMa-Open-Instruct (e.g., VMware/open-llama-7b-open-instruct)""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return ( + "open-llama" in model_path.lower() and "open-instruct" in model_path.lower() + ) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=self.use_fast_tokenizer, + trust_remote_code=True, + revision=revision, + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ).eval() + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("alpaca") + + +class CodeLlamaAdapter(BaseModelAdapter): + """The model adapter for CodeLlama (e.g., codellama/CodeLlama-34b-hf)""" + + def match(self, model_path: str): + return "codellama" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("llama-2") + + +class StableVicunaAdapter(BaseModelAdapter): + """The model adapter for StableVicuna""" + + def match(self, model_path: str): + return "stable-vicuna" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("stable-vicuna") + + +class PhindCodeLlamaAdapter(CodeLlamaAdapter): + """The model adapter for Phind-CodeLlama (e.g., Phind/Phind-CodeLlama-34B-v2)""" + + def match(self, model_path: str): + return "phind-codellama-" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("phind") + + +class Llama2ChangAdapter(Llama2Adapter): + """The model adapter for Llama2-ko-chang (e.g., lcw99/llama2-ko-chang-instruct-chat)""" + + def match(self, model_path: str): + return "llama2-ko-chang" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("polyglot_changgpt") + + +class ZephyrAdapter(BaseModelAdapter): + """The model adapter for Zephyr (e.g. HuggingFaceH4/zephyr-7b-alpha)""" + + def match(self, model_path: str): + return "zephyr" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("zephyr") + + +class NotusAdapter(BaseModelAdapter): + """The model adapter for Notus (e.g. argilla/notus-7b-v1)""" + + def match(self, model_path: str): + return "notus" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("zephyr") + + +class CatPPTAdapter(BaseModelAdapter): + """The model adapter for CatPPT (e.g. 
rishiraj/CatPPT)""" + + def match(self, model_path: str): + return "catppt" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("catppt") + + +class TinyLlamaAdapter(BaseModelAdapter): + """The model adapter for TinyLlama (e.g. TinyLlama/TinyLlama-1.1B-Chat-v1.0)""" + + def match(self, model_path: str): + return "tinyllama" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("TinyLlama") + + +class XwinLMAdapter(BaseModelAdapter): + """The model adapter for Xwin-LM V0.1 and V0.2 series of models(e.g., Xwin-LM/Xwin-LM-70B-V0.1)""" + + # use_fast_tokenizer = False + + def match(self, model_path: str): + return "xwin-lm" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("vicuna_v1.1") + + +class LemurAdapter(BaseModelAdapter): + """The model adapter for OpenLemur/lemur-70b-chat-v1""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "lemur-70b-chat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("lemur-70b-chat") + + +class PygmalionAdapter(BaseModelAdapter): + """The model adapter for Pygmalion/Metharme series of models(e.g., PygmalionAI/mythalion-13b)""" + + # use_fast_tokenizer = False + + def match(self, model_path: str): + return bool( + re.search(r"pygmalion|mythalion|metharme", model_path.lower(), re.I) + ) + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("metharme") + + +class XdanAdapter(BaseModelAdapter): + """The model adapter for xDAN-AI (e.g. xDAN-AI/xDAN-L1-Chat-RL-v1)""" + + def match(self, model_path: str): + return "xdan" in model_path.lower() and "v1" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("xdan-v1") + + +class MicrosoftOrcaAdapter(BaseModelAdapter): + """The model adapter for Microsoft/Orca-2 series of models (e.g. 
Microsoft/Orca-2-7b, Microsoft/Orca-2-13b)""" + + use_fast_tokenizer = False # Flag neeeded since tokenizers>=0.13.3 is required for a normal functioning of this module + + def match(self, model_path: str): + return "orca-2" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("orca-2") + + +class YiAdapter(BaseModelAdapter): + """The model adapter for Yi models""" + + def match(self, model_path: str): + return "yi-" in model_path.lower() and "chat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("Yi-34b-chat") + + +class DeepseekCoderAdapter(BaseModelAdapter): + """The model adapter for deepseek-ai's coder models""" + + def match(self, model_path: str): + return "deepseek-coder" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("deepseek-coder") + + +class DeepseekChatAdapter(BaseModelAdapter): + """The model adapter for deepseek-ai's chat models""" + + # Note: that this model will require tokenizer version >= 0.13.3 because the tokenizer class is LlamaTokenizerFast + + def match(self, model_path: str): + return "deepseek-llm" in model_path.lower() and "chat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("deepseek-chat") + + +class GeminiAdapter(BaseModelAdapter): + """The model adapter for Gemini""" + + def match(self, model_path: str): + return "gemini" in model_path.lower() or "bard" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "gemini-1.5-pro" in model_path: + return get_conv_template("gemini-1.5-pro") + return get_conv_template("gemini") + + +class Yuan2Adapter(BaseModelAdapter): + """The model adapter for Yuan2.0""" + + def match(self, model_path: str): + return "yuan2" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + # from_pretrained_kwargs["torch_dtype"] = torch.bfloat16 + tokenizer = LlamaTokenizer.from_pretrained( + model_path, + add_eos_token=False, + add_bos_token=False, + eos_token="", + eod_token="", + sep_token="", + revision=revision, + ) + tokenizer.add_tokens( + [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + ], + special_tokens=True, + ) + + model = AutoModelForCausalLM.from_pretrained( + model_path, + # device_map='auto', + trust_remote_code=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("yuan2") + + +class MetaMathAdapter(BaseModelAdapter): + """The model adapter for MetaMath models""" + + def match(self, model_path: str): + return "metamath" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("metamath") + + +class BagelAdapter(BaseModelAdapter): + """Model adapter for jondurbin/bagel-* models""" + + def match(self, model_path: str): + return "bagel" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("airoboros_v3") + + +class SolarAdapter(BaseModelAdapter): + """The model adapter for 
upstage/SOLAR-10.7B-Instruct-v1.0""" + + def match(self, model_path: str): + return "solar-" in model_path.lower() and "instruct" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("solar") + + +class SteerLMAdapter(BaseModelAdapter): + """The model adapter for nvidia/Llama2-70B-SteerLM-Chat""" + + def match(self, model_path: str): + return "steerlm-chat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("steerlm") + + +class GemmaAdapter(BaseModelAdapter): + """The model adapter for google/gemma""" + + def match(self, model_path: str): + return "gemma" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("gemma") + + +class LlavaAdapter(BaseModelAdapter): + """The model adapter for liuhaotian/llava-v1.5 series of models""" + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + # TODO(chris): Implement huggingface-compatible load_model + pass + + def match(self, model_path: str): + return "llava" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + model_path = model_path.lower() + if "34b" in model_path: + return get_conv_template("llava-chatml") + + return get_conv_template("vicuna_v1.1") + + +class YuanAdapter(BaseModelAdapter): + """The model adapter for Yuan""" + + def match(self, model_path: str): + return "yuan" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + tokenizer.add_tokens( + [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + ], + special_tokens=True, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("yuan") + + +class OlmoAdapter(BaseModelAdapter): + """The model adapter for allenai/OLMo-7B-Instruct""" + + def match(self, model_path: str): + return "olmo" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("api_based_default") + + +class YandexGPTAdapter(BaseModelAdapter): + """The model adapter for YandexGPT""" + + def match(self, model_path: str): + return "yandexgpt" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("yandexgpt") + + +class CllmAdapter(BaseModelAdapter): + """The model adapter for CLLM""" + + def match(self, model_path: str): + return "consistency-llm" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + config = AutoConfig.from_pretrained( + model_path, + ) + + tokenizer = AutoTokenizer.from_pretrained( + model_path, + model_max_length=2048, + padding_side="right", + ) + + model = AutoModelForCausalLM.from_pretrained( + model_path, + config=config, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + device_map="cuda", + ) + + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("cllm") + + +class CohereAdapter(BaseModelAdapter): + """The model adapter for Cohere""" + + def match(self, model_path: str): + return model_path in ["command-r"] + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def 
get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("api_based_default")
+
+
+class DBRXAdapter(BaseModelAdapter):
+    """The model adapter for DBRX (dbrx-instruct)"""
+
+    def match(self, model_path: str):
+        return model_path in ["dbrx-instruct"]
+
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        raise NotImplementedError()
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("api_based_default")
+
+
+class RekaAdapter(BaseModelAdapter):
+    """The model adapter for Reka"""
+
+    def match(self, model_path: str):
+        return "reka" in model_path.lower()
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("api_based_default")
+
+
+# Note: the registration order matters.
+# An adapter registered earlier has a higher matching priority.
+register_model_adapter(PeftModelAdapter)
+register_model_adapter(StableVicunaAdapter)
+register_model_adapter(VicunaAdapter)
+register_model_adapter(AiroborosAdapter)
+register_model_adapter(LongChatAdapter)
+register_model_adapter(GoogleT5Adapter)
+register_model_adapter(KoalaAdapter)
+register_model_adapter(AlpacaAdapter)
+register_model_adapter(ChatGLMAdapter)
+register_model_adapter(CodeGeexAdapter)
+register_model_adapter(DollyV2Adapter)
+register_model_adapter(OasstPythiaAdapter)
+register_model_adapter(OasstLLaMAAdapter)
+register_model_adapter(OpenChat35Adapter)
+register_model_adapter(TenyxChatAdapter)
+register_model_adapter(StableLMAdapter)
+register_model_adapter(BaizeAdapter)
+register_model_adapter(RwkvAdapter)
+register_model_adapter(OpenBuddyAdapter)
+register_model_adapter(PhoenixAdapter)
+register_model_adapter(BardAdapter)
+register_model_adapter(PaLM2Adapter)
+register_model_adapter(GeminiAdapter)
+register_model_adapter(GeminiDevAdapter)
+register_model_adapter(GemmaAdapter)
+register_model_adapter(ChatGPTAdapter)
+register_model_adapter(AzureOpenAIAdapter)
+register_model_adapter(ClaudeAdapter)
+register_model_adapter(MPTAdapter)
+register_model_adapter(BiLLaAdapter)
+register_model_adapter(RedPajamaINCITEAdapter)
+register_model_adapter(H2OGPTAdapter)
+register_model_adapter(RobinAdapter)
+register_model_adapter(SnoozyAdapter)
+register_model_adapter(WizardLMAdapter)
+register_model_adapter(ManticoreAdapter)
+register_model_adapter(GuanacoAdapter)
+register_model_adapter(CamelAdapter)
+register_model_adapter(ChangGPTAdapter)
+register_model_adapter(TuluAdapter)
+register_model_adapter(FalconChatAdapter)
+register_model_adapter(FalconAdapter)
+register_model_adapter(TigerBotAdapter)
+register_model_adapter(BaichuanAdapter)
+register_model_adapter(XGenAdapter)
+register_model_adapter(PythiaAdapter)
+register_model_adapter(InternLMChatAdapter)
+register_model_adapter(StarChatAdapter)
+register_model_adapter(Llama2Adapter)
+register_model_adapter(CuteGPTAdapter)
+register_model_adapter(OpenOrcaAdapter)
+register_model_adapter(DolphinAdapter)
+register_model_adapter(Hermes2Adapter)
+register_model_adapter(NousHermes2MixtralAdapter)
+register_model_adapter(NousHermesAdapter)
+register_model_adapter(MistralAdapter)
+register_model_adapter(WizardCoderAdapter)
+register_model_adapter(QwenChatAdapter)
+register_model_adapter(AquilaChatAdapter)
+register_model_adapter(BGEAdapter)
+register_model_adapter(E5Adapter)
+register_model_adapter(Lamma2ChineseAdapter)
+register_model_adapter(Lamma2ChineseAlpacaAdapter)
+register_model_adapter(VigogneAdapter)
+register_model_adapter(OpenLLaMaOpenInstructAdapter)
+register_model_adapter(ReaLMAdapter) +register_model_adapter(PhindCodeLlamaAdapter) +register_model_adapter(CodeLlamaAdapter) +register_model_adapter(Llama2ChangAdapter) +register_model_adapter(ZephyrAdapter) +register_model_adapter(NotusAdapter) +register_model_adapter(CatPPTAdapter) +register_model_adapter(TinyLlamaAdapter) +register_model_adapter(XwinLMAdapter) +register_model_adapter(LemurAdapter) +register_model_adapter(PygmalionAdapter) +register_model_adapter(MicrosoftOrcaAdapter) +register_model_adapter(XdanAdapter) +register_model_adapter(YiAdapter) +register_model_adapter(PplxAIAdapter) +register_model_adapter(DeepseekCoderAdapter) +register_model_adapter(DeepseekChatAdapter) +register_model_adapter(Yuan2Adapter) +register_model_adapter(MetaMathAdapter) +register_model_adapter(BagelAdapter) +register_model_adapter(SolarAdapter) +register_model_adapter(SteerLMAdapter) +register_model_adapter(LlavaAdapter) +register_model_adapter(YuanAdapter) +register_model_adapter(OlmoAdapter) +register_model_adapter(CohereAdapter) +register_model_adapter(DBRXAdapter) +register_model_adapter(GemmaAdapter) +register_model_adapter(YandexGPTAdapter) +register_model_adapter(CllmAdapter) +register_model_adapter(RekaAdapter) +register_model_adapter(SmaugChatAdapter) +register_model_adapter(Llama3Adapter) + +# After all adapters, try the default base adapter. +register_model_adapter(BaseModelAdapter) diff --git a/fastchat/model/model_chatglm.py b/fastchat/model/model_chatglm.py new file mode 100644 index 0000000000000000000000000000000000000000..2cbac8bc5f9f5ccbee833ac9cc22cf23c068e51e --- /dev/null +++ b/fastchat/model/model_chatglm.py @@ -0,0 +1,137 @@ +""" +Inference code for ChatGLM. +Adapted from https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py. 
+""" +import re + +import torch +from transformers.generation.logits_process import LogitsProcessor + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor + ) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +invalid_score_processor = InvalidScoreLogitsProcessor() + + +def process_response(response): + response = response.strip() + response = response.replace("[[训练时间]]", "2023年") + punkts = [ + [",", ","], + ["!", "!"], + [":", ":"], + [";", ";"], + ["\?", "?"], + ] + for item in punkts: + response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response) + response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response) + return response + + +def recover_message_list(prompt): + role_token_pattern = "|".join( + [re.escape(r) for r in ["<|system|>", "<|user|>", "<|assistant|>"]] + ) + role = None + last_end_idx = -1 + message_list = [] + for match in re.finditer(role_token_pattern, prompt): + if role: + messge = {} + if role == "<|system|>": + messge["role"] = "system" + elif role == "<|user|>": + messge["role"] = "user" + else: + messge["role"] = "assistant" + messge["content"] = prompt[last_end_idx + 1 : match.start()] + message_list.append(messge) + + role = prompt[match.start() : match.end()] + last_end_idx = match.end() + + return message_list + + +@torch.inference_mode() +def generate_stream_chatglm( + model, + tokenizer, + params, + device, + context_len=2048, + stream_interval=2, + judge_sent_end=False, +): + prompt = params["prompt"] + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + max_new_tokens = int(params.get("max_new_tokens", 256)) + echo = params.get("echo", True) + + model_type = str(type(model)).lower() + if "peft" in model_type: + model_type = str(type(model.base_model.model)).lower() + + if "chatglm3" in model_type: + message_list = recover_message_list(prompt) + inputs = tokenizer.build_chat_input( + query=message_list[-1]["content"], history=message_list[:-1], role="user" + ).to(model.device) + else: + inputs = tokenizer([prompt], return_tensors="pt").to(model.device) + input_echo_len = len(inputs["input_ids"][0]) + + gen_kwargs = { + "max_length": max_new_tokens + input_echo_len, + "do_sample": True if temperature > 1e-5 else False, + "top_p": top_p, + "repetition_penalty": repetition_penalty, + "logits_processor": [invalid_score_processor], + } + if temperature > 1e-5: + gen_kwargs["temperature"] = temperature + + total_len = 0 + for total_ids in model.stream_generate(**inputs, **gen_kwargs): + total_ids = total_ids.tolist()[0] + total_len = len(total_ids) + if echo: + output_ids = total_ids + else: + output_ids = total_ids[input_echo_len:] + response = tokenizer.decode(output_ids) + response = process_response(response) + + yield { + "text": response, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": total_len - input_echo_len, + "total_tokens": total_len, + }, + "finish_reason": None, + } + + # TODO: ChatGLM stop when it reach max length + # Only last stream result contains finish_reason, we set finish_reason as stop + ret = { + "text": response, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": total_len - input_echo_len, + "total_tokens": total_len, + }, + "finish_reason": "stop", + } + yield ret diff --git 
a/fastchat/model/model_cllm.py b/fastchat/model/model_cllm.py new file mode 100644 index 0000000000000000000000000000000000000000..563e2a5598233788a4a165eb41371fd7ab729f62 --- /dev/null +++ b/fastchat/model/model_cllm.py @@ -0,0 +1,202 @@ +import torch +import gc + +import os +import time +import random +from typing import Dict, Optional, Sequence, List, Tuple +from transformers.cache_utils import Cache, DynamicCache +from transformers import ( + LlamaModel, + LlamaForCausalLM, + GenerationConfig, + StoppingCriteria, + StoppingCriteriaList, + TextIteratorStreamer, +) +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +import torch.nn.functional as F + + +def get_jacobian_trajectory( + model, tokenizer, input_ids, attention_mask, max_new_tokens +): + bsz = input_ids.shape[0] + prompt_len = [torch.sum(t) for t in attention_mask] + max_prompt_len = max(prompt_len) + total_len = max_prompt_len + max_new_tokens + + # initialize the first point of jacobian trajectory + tokens = torch.full( + (bsz, total_len), tokenizer.pad_token_id, dtype=torch.long, device=model.device + ) + for i in range(bsz): + tokens[i, :] = torch.tensor( + random.choices(input_ids[i][attention_mask[i] == 1], k=total_len), + dtype=torch.long, + device=model.device, + ) + tokens[i, : prompt_len[i]] = input_ids[i][: prompt_len[i]].to( + dtype=torch.long, device=model.device + ) + itr = 0 + next_generation = tokens + generate_attention_mask = torch.full_like(next_generation, 1).to(model.device) + accurate_lengths = torch.tensor([prompt_len[i].item()] * bsz, device=model.device) + prev_len = 0 + while True: + current_generation = next_generation + with torch.no_grad(): + logits = model(current_generation, generate_attention_mask).logits + next_generation = torch.argmax( + torch.nn.functional.softmax(logits, dim=-1) / 0.001, dim=-1 + ) + + # hold prompt unchanged and update generated tokens + for i in range(bsz): + next_generation[i, :] = torch.cat( + ( + tokens[i, : prompt_len[i]], + next_generation[i, prompt_len[i] - 1 : total_len - 1], + ), + dim=0, + ) + + if ( + torch.all(torch.eq(next_generation, current_generation)).item() + and itr == max_new_tokens + or len( + torch.where( + current_generation[0, : accurate_lengths[0]] + == tokenizer.eos_token_id + )[0] + ) + > 0 + ): + # forced exit due to max_new_tokens constraint or eos reached + return next_generation, itr + + # skip the first itr, current_generation has not been updated yet + if itr != 0: + if torch.all(torch.eq(next_generation, current_generation)).item(): + matched_position = total_len + else: + matched_position = ( + torch.eq(current_generation, next_generation).squeeze(0) == False + ).nonzero(as_tuple=True)[0][0] + fast_forward_cnt = matched_position - accurate_lengths[0] + + for i in range(bsz): + accurate_lengths[i] = matched_position.item() + + # flush and print the first sequence + generated_str = tokenizer.decode( + next_generation[0, prompt_len[0] : accurate_lengths[0]], + skip_special_tokens=True, + spaces_between_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + print(generated_str[prev_len:], flush=True, end="") + prev_len = len(generated_str) + + if torch.all(torch.eq(next_generation, current_generation)).item(): + # early termination: itr < max_new_tokens + return next_generation, itr + + itr += 1 + + +def generate_stream_cllm( + model, + tokenizer, + params, + device, + context_len, + stream_interval=2, + judge_sent_end=False, +): + # converge_step = [] + prompt = params["prompt"] + inputs = 
tokenizer(prompt, return_tensors="pt").to(device)
+    max_new_tokens = int(params.get("n_token_seq_length", 32))
+    max_new_seq_len = int(params.get("max_new_tokens", 1024))
+
+    prompt_len = torch.sum(inputs["attention_mask"], dim=-1)
+    generation = inputs["input_ids"]
+    input_echo_len = generation.shape[-1]  # number of prompt tokens
+
+    ### generation phase
+    itr = 0
+    eos_reached = False
+    while True:
+        if itr == 0:
+            input_ids = inputs["input_ids"]
+            input_masks = inputs["attention_mask"]
+        else:
+            input_masks = torch.ones_like(input_ids).to(device)
+            for j in range(bsz):
+                input_masks[j][
+                    torch.sum(inputs["attention_mask"], dim=-1)[j]
+                    + itr * max_new_tokens :
+                ] = 0
+
+        bsz = input_ids.shape[0]
+        eos_reached = torch.tensor([False] * bsz, device=device)
+
+        generation, iter_steps = get_jacobian_trajectory(
+            model=model,
+            tokenizer=tokenizer,
+            input_ids=input_ids,
+            attention_mask=input_masks,
+            max_new_tokens=max_new_tokens,
+        )
+
+        ### inspect each sample for EOS
+        for j in range(bsz):
+            prompt_len = torch.sum(input_masks, dim=-1)
+            eos_positions = torch.where(generation[j] == tokenizer.eos_token_id)[0]
+
+            if len(eos_positions) == 0:
+                # no EOS generated yet: pad the unused tail and move to the next sample
+                generation[j][prompt_len[j] + max_new_tokens :] = tokenizer.pad_token_id
+                continue
+            else:
+                # EOS found: mark the sample as finished and pad everything after EOS
+                eos_reached[j] = True
+                generation[j, int(eos_positions[0]) + 1 :] = tokenizer.pad_token_id
+
+        itr += 1
+
+        if all(eos_reached) or itr * max_new_tokens >= max_new_seq_len:
+            break
+        input_ids = generation[
+            torch.where(eos_reached == False)[0].tolist(), ...
+        ]  # keep only the samples that have not generated EOS yet
+
+    if all(eos_reached):
+        finish_reason = "eos"
+    elif itr * max_new_tokens >= max_new_seq_len:
+        finish_reason = "length"
+    else:
+        finish_reason = "stop"
+
+    output = tokenizer.decode(input_ids[0], skip_special_tokens=False)
+
+    # return the decoded text in a single final chunk
+    yield {
+        "text": output,
+        "usage": {
+            "prompt_tokens": input_echo_len,
+            "completion_tokens": itr * max_new_tokens,
+            "total_tokens": input_echo_len + itr * max_new_tokens,
+        },
+        "finish_reason": finish_reason,
+    }
+
+    # clean
+    gc.collect()
+    torch.cuda.empty_cache()
+    if device == "xpu":
+        torch.xpu.empty_cache()
+    if device == "npu":
+        torch.npu.empty_cache()
diff --git a/fastchat/model/model_codet5p.py b/fastchat/model/model_codet5p.py
new file mode 100644
index 0000000000000000000000000000000000000000..0984513c96931b6d48dfd17f3020fe5cebc3f911
--- /dev/null
+++ b/fastchat/model/model_codet5p.py
@@ -0,0 +1,108 @@
+import gc
+from threading import Thread
+import torch
+import transformers
+from transformers import (
+    GenerationConfig,
+    StoppingCriteria,
+    StoppingCriteriaList,
+    TextIteratorStreamer,
+)
+
+
+@torch.inference_mode()
+def generate_stream_codet5p(
+    model,
+    tokenizer,
+    params,
+    device,
+    context_len=2048,
+    stream_interval=2,
+    judge_sent_end=False,
+):
+    prompt = params["prompt"]
+    temperature = float(params.get("temperature", 1.0))
+    repetition_penalty = float(params.get("repetition_penalty", 1.0))
+    top_p = float(params.get("top_p", 1.0))
+    top_k = int(params.get("top_k", 50))  # -1 means disable
+    max_new_tokens = int(params.get("max_new_tokens", 1024))
+    stop_token_ids = params.get("stop_token_ids", None) or []
+    stop_token_ids.append(tokenizer.eos_token_id)
+
+    decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
+    streamer = TextIteratorStreamer(tokenizer, **decode_config)
+    encoding = tokenizer(prompt, return_tensors="pt").to(device)
+    input_ids = encoding.input_ids
+    encoding["decoder_input_ids"] = 
encoding["input_ids"].clone() + input_echo_len = len(input_ids) + + generation_config = GenerationConfig( + max_new_tokens=max_new_tokens, + do_sample=temperature >= 1e-5, + temperature=temperature, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=10, + top_p=top_p, + top_k=top_k, + eos_token_id=stop_token_ids, + ) + + class CodeBlockStopper(StoppingCriteria): + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs + ) -> bool: + # Code-completion is open-end generation. + # We check \n\n to stop at end of a code block. + if list(input_ids[0][-2:]) == [628, 198]: + return True + return False + + gen_kwargs = dict( + **encoding, + streamer=streamer, + generation_config=generation_config, + stopping_criteria=StoppingCriteriaList([CodeBlockStopper()]), + ) + thread = Thread(target=model.generate, kwargs=gen_kwargs) + thread.start() + i = 0 + output = "" + for new_text in streamer: + i += 1 + output += new_text + if i % stream_interval == 0 or i == max_new_tokens - 1: + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + if i >= max_new_tokens: + break + + if i >= max_new_tokens: + finish_reason = "length" + else: + finish_reason = "stop" + + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + thread.join() + + # clean + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() diff --git a/fastchat/model/model_exllama.py b/fastchat/model/model_exllama.py new file mode 100644 index 0000000000000000000000000000000000000000..306edab21a79658d22eb75f1da3eba1f830e4ae7 --- /dev/null +++ b/fastchat/model/model_exllama.py @@ -0,0 +1,77 @@ +import gc +import sys +from typing import Dict + +import torch + + +def generate_stream_exllama( + model, + tokenizer, + params: Dict, + device: str, + context_len: int, + stream_interval: int = 2, + judge_sent_end: bool = False, +): + try: + from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler + except ImportError as e: + print(f"Error: Failed to load Exllamav2. 
{e}") + sys.exit(-1) + + prompt = params["prompt"] + + generator = ExLlamaV2StreamingGenerator(model.model, model.cache, tokenizer) + settings = ExLlamaV2Sampler.Settings() + + settings.temperature = float(params.get("temperature", 0.85)) + settings.top_k = int(params.get("top_k", 50)) + settings.top_p = float(params.get("top_p", 0.8)) + settings.token_repetition_penalty = float(params.get("repetition_penalty", 1.15)) + settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id]) + + max_new_tokens = int(params.get("max_new_tokens", 256)) + + generator.set_stop_conditions(params.get("stop_token_ids", None) or []) + echo = bool(params.get("echo", True)) + + input_ids = generator.tokenizer.encode(prompt) + prompt_tokens = input_ids.shape[-1] + generator.begin_stream(input_ids, settings) + + generated_tokens = 0 + if echo: + output = prompt + else: + output = "" + while True: + chunk, eos, _ = generator.stream() + output += chunk + generated_tokens += 1 + if generated_tokens == max_new_tokens: + finish_reason = "length" + break + elif eos: + finish_reason = "length" + break + yield { + "text": output, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": generated_tokens, + "total_tokens": prompt_tokens + generated_tokens, + }, + "finish_reason": None, + } + + yield { + "text": output, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": generated_tokens, + "total_tokens": prompt_tokens + generated_tokens, + }, + "finish_reason": finish_reason, + } + gc.collect() diff --git a/fastchat/model/model_falcon.py b/fastchat/model/model_falcon.py new file mode 100644 index 0000000000000000000000000000000000000000..dc8af8efa20bd29fb31cdd0a0bc039b30f4bf26e --- /dev/null +++ b/fastchat/model/model_falcon.py @@ -0,0 +1,140 @@ +import gc +from threading import Thread +from typing import Iterable + +import torch +import transformers +from transformers import TextIteratorStreamer, GenerationConfig + +from fastchat.utils import is_partial_stop + + +@torch.inference_mode() +def generate_stream_falcon( + model, + tokenizer, + params, + device, + context_len=2048, + stream_interval=2, + judge_sent_end=False, +): + prompt = params["prompt"] + len_prompt = len(prompt) + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = int(params.get("top_k", 50)) # -1 means disable + max_new_tokens = int(params.get("max_new_tokens", 256)) + stop_str = params.get("stop", None) + echo = bool(params.get("echo", True)) + stop_token_ids = params.get("stop_token_ids", None) or [] + stop_token_ids.append(tokenizer.eos_token_id) + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + + max_src_len = context_len - max_new_tokens - 8 + + input_ids = input_ids[-max_src_len:] # truncate from the left + attention_mask = attention_mask[-max_src_len:] # truncate from the left + input_echo_len = len(input_ids) + + decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config) + + generation_config = GenerationConfig( + max_new_tokens=max_new_tokens, + do_sample=temperature >= 1e-5, + temperature=temperature, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=10, + top_p=top_p, + top_k=top_k, + eos_token_id=stop_token_ids, + ) + + generation_kwargs = dict( + 
inputs=input_ids, + attention_mask=attention_mask, + streamer=streamer, + generation_config=generation_config, + ) + + thread = Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + + if echo: + # means keep the prompt + output = prompt + else: + output = "" + + for i, new_text in enumerate(streamer): + output += new_text + if i % stream_interval == 0: + if echo: + rfind_start = len_prompt + else: + rfind_start = 0 + + partially_stopped = False + if stop_str: + if isinstance(stop_str, str): + pos = output.rfind(stop_str, rfind_start) + if pos != -1: + output = output[:pos] + else: + partially_stopped = is_partial_stop(output, stop_str) + elif isinstance(stop_str, Iterable): + for each_stop in stop_str: + pos = output.rfind(each_stop, rfind_start) + if pos != -1: + output = output[:pos] + break + else: + partially_stopped = is_partial_stop(output, each_stop) + if partially_stopped: + break + else: + raise ValueError("Invalid stop field type.") + + # prevent yielding partial stop sequence + if not partially_stopped: + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + output = output.strip() + + # finish stream event, which contains finish reason + if i == max_new_tokens - 1: + finish_reason = "length" + elif partially_stopped: + finish_reason = None + else: + finish_reason = "stop" + + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + + # clean + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() diff --git a/fastchat/model/model_registry.py b/fastchat/model/model_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..75b47a705c54f1aca8a468f45b03cfc2715a49b6 --- /dev/null +++ b/fastchat/model/model_registry.py @@ -0,0 +1,888 @@ +"""Additional information of the models.""" +from collections import namedtuple, OrderedDict +from typing import List + + +ModelInfo = namedtuple("ModelInfo", ["simple_name", "link", "description"]) + + +model_info = OrderedDict() + + +def register_model_info( + full_names: List[str], simple_name: str, link: str, description: str +): + info = ModelInfo(simple_name, link, description) + + for full_name in full_names: + model_info[full_name] = info + + +def get_model_info(name: str) -> ModelInfo: + if name in model_info: + return model_info[name] + else: + # To fix this, please use `register_model_info` to register your model + return ModelInfo( + name, "", "Register the description at fastchat/model/model_registry.py" + ) + + +register_model_info( + [ + "IEITYuan/Yuan2-2B-Janus-hf", + "IEITYuan/Yuan2-2B-hf", + "IEITYuan/Yuan2-51B-hf", + "IEITYuan/Yuan2-102B-hf", + ], + "IEIT-Yuan2", + "https://github.com/IEIT-Yuan/Yuan-2.0", + "Yuan2.0 is a new generation Fundamental Large Language Model developed by IEIT System.", +) + +register_model_info( + [ + "claude-3-5-sonnet-20240620", + "claude-3-haiku-20240307", + "claude-3-sonnet-20240229", + "claude-3-opus-20240229", + "claude-2.1", + "claude-2.0", + "claude-1", + ], + "Claude", + "https://www.anthropic.com/news/claude-3-family", + "Claude by Anthropic", +) + +register_model_info( + ["gemma-2-27b-it", "gemma-2-9b-it"], + "Gemma 2", + "https://ai.google.dev/gemma", + "Gemma 2 by Google", +) + +register_model_info( + [ + 
"gemini-1.5-pro-api-0409-preview", + "gemini-1.5-pro-tuned", + "gemini-1.5-pro-api-preview", + "gemini-1.5-flash-api-preview", + "gemini-1.5-flash-api-0514", + "gemini-1.5-pro-api-0514", + "gemini-advanced-0514", + ], + "Gemini", + "https://deepmind.google/technologies/gemini/", + "Gemini by Google", +) + +register_model_info( + ["deepseek-coder-v2"], + "DeepSeek Coder v2", + "https://huggingface.co/deepseek-ai/DeepSeek-Coder-V2-Instruct", + "An advanced code model by DeepSeek", +) + +register_model_info( + ["nemotron-4-340b"], + "Nemotron-4 340B", + "https://huggingface.co/nvidia/Nemotron-4-340B-Instruct", + "Cutting-edge Open model by Nvidia", +) + +register_model_info( + ["llama-3-70b-instruct", "llama-3-8b-instruct"], + "Llama 3", + "https://ai.meta.com/blog/meta-llama-3/", + "Open foundation and chat models by Meta", +) + +register_model_info( + [ + "qwen2-72b-instruct", + "qwen-max-0403", + "qwen-max-0428", + ], + "Qwen Max", + "https://help.aliyun.com/zh/dashscope/developer-reference/model-introduction", + "The Frontier Qwen Model by Alibaba", +) + +register_model_info( + [ + "gpt-4o-2024-05-13", + ], + "GPT-4o", + "https://openai.com/index/hello-gpt-4o/", + "The flagship model across audio, vision, and text by OpenAI", +) + +register_model_info( + [ + "gpt-4-turbo", + "gpt-4o-2024-05-13", + "gpt-4-turbo-2024-04-09", + "gpt-4-1106-preview", + "gpt-4-0125-preview", + "gpt2-chatbot", + "im-also-a-good-gpt2-chatbot", + "im-a-good-gpt2-chatbot", + ], + "GPT-4-Turbo", + "https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo", + "GPT-4-Turbo by OpenAI", +) + +register_model_info( + [ + "gpt-3.5-turbo", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0314", + "gpt-3.5-turbo-0613", + ], + "GPT-3.5", + "https://platform.openai.com/docs/models/gpt-3-5", + "GPT-3.5-Turbo by OpenAI", +) + +register_model_info( + ["yi-large-preview", "yi-large"], + "Yi-Large", + "https://x.com/01AI_Yi/status/1789894091620458667", + "State-of-the-art model by 01 AI", +) + +register_model_info( + ["yi-1.5-34b-chat", "yi-34b-chat", "yi-6b-chat"], + "Yi-Chat", + "https://huggingface.co/01-ai/", + "A large language model by 01 AI", +) + +register_model_info( + [ + "phi-3-medium-4k-instruct", + "phi-3-small-8k-instruct", + "phi-3-mini-4k-instruct", + "phi-3-mini-128k-instruct", + ], + "Phi-3", + "https://azure.microsoft.com/en-us/blog/introducing-phi-3-redefining-whats-possible-with-slms/", + "A capable and cost-effective small language models (SLMs) by Microsoft", +) + +register_model_info( + ["reka-core-20240501"], + "Reka Core", + "https://www.reka.ai/news/reka-core-our-frontier-class-multimodal-language-model", + "Frontier Multimodal Language Model by Reka", +) + +register_model_info( + ["gemini-pro", "gemini-pro-dev-api"], + "Gemini", + "https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/", + "Gemini by Google", +) + +register_model_info( + ["reka-flash-preview-20240611", "reka-flash", "reka-flash-online"], + "Reka Flash", + "https://www.reka.ai/news/reka-flash-efficient-and-capable-multimodal-language-models", + "Multimodal model by Reka", +) + +register_model_info( + ["command-r-plus"], + "Command-R-Plus", + "https://txt.cohere.com/command-r-plus-microsoft-azure/", + "Command-R Plus by Cohere", +) + +register_model_info( + ["command-r"], + "Command-R", + "https://txt.cohere.com/command-r/", + "Command-R by Cohere", +) + +register_model_info( + ["gemma-1.1-7b-it", "gemma-1.1-2b-it", "gemma-7b-it", "gemma-2b-it"], + "Gemma", + 
"https://blog.google/technology/developers/gemma-open-models/", + "Gemma by Google", +) + +register_model_info( + [ + "codestral-2405", + "mixtral-8x22b-instruct-v0.1", + ], + "Mixtral of experts", + "https://mistral.ai/news/mixtral-8x22b/", + "A Mixture-of-Experts model by Mistral AI", +) + +register_model_info( + [ + "mixtral-8x7b-instruct-v0.1", + "mistral-large-2402", + "mistral-medium", + "mistral-next", + "mistral-7b-instruct-v0.2", + "mistral-7b-instruct", + ], + "Mixtral of experts", + "https://mistral.ai/news/mixtral-of-experts/", + "A Mixture-of-Experts model by Mistral AI", +) + +register_model_info( + [ + "qwen1.5-110b-chat", + ], + "Qwen 1.5", + "https://qwenlm.github.io/blog/qwen1.5-110b/", + "The First 100B+ Model of the Qwen1.5 Series", +) + +register_model_info( + [ + "qwen1.5-72b-chat", + "qwen1.5-32b-chat", + "qwen1.5-14b-chat", + "qwen1.5-7b-chat", + "qwen1.5-4b-chat", + "qwen1.5-1.8b-chat", + "qwen1.5-0.5b-chat", + "qwen-14b-chat", + ], + "Qwen 1.5", + "https://qwenlm.github.io/blog/qwen1.5/", + "A large language model by Alibaba Cloud", +) + +register_model_info( + ["glm-4-0520", "glm-4-0116"], + "GLM-4", + "https://zhipuai.cn/devday", + "Next-Gen Foundation Model by Zhipu AI", +) + +register_model_info( + ["snowflake-arctic-instruct"], + "Snowflake Arctic Instruct", + "https://www.snowflake.com/blog/arctic-open-efficient-foundation-language-models-snowflake/", + "Dense-MoE transformer by Snowflake AI", +) + +register_model_info( + ["dbrx-instruct", "dbrx-next"], + "DBRX Instruct", + "https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm", + "DBRX by Databricks Mosaic AI", +) + +register_model_info( + [ + "zephyr-orpo-141b-A35b-v0.1", + ], + "Zephyr 141B-A35B", + "https://huggingface.co/HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1", + "ORPO fine-tuned of Mixtral-8x22B-v0.1", +) + +register_model_info( + ["starling-lm-7b-beta", "starling-lm-7b-alpha"], + "Starling-LM-7B", + "https://starling.cs.berkeley.edu/", + "An open model trained using RLAIF by Berkeley", +) + +register_model_info( + ["qwen-14b-chat"], + "Qwen", + "https://huggingface.co/Qwen", + "A large language model by Alibaba Cloud", +) + +register_model_info( + ["bard-feb-2024", "bard-jan-24-gemini-pro"], + "Bard", + "https://bard.google.com/", + "Bard by Google", +) + +register_model_info( + [ + "gemini-1.0-pro-vision", + "gemini-1.5-pro-preview-0409", + ], + "Gemini", + "https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/", + "Gemini by Google", +) + +register_model_info( + ["stripedhyena-nous-7b"], + "StripedHyena-Nous", + "https://huggingface.co/togethercomputer/StripedHyena-Nous-7B", + "A chat model developed by Together Research and Nous Research.", +) + +register_model_info( + ["solar-10.7b-instruct-v1.0"], + "SOLAR-10.7B-Instruct", + "https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0", + "A model trained using depth up-scaling by Upstage AI", +) + +register_model_info( + ["gpt-4-turbo-browsing"], + "GPT-4-Turbo with browsing", + "https://platform.openai.com/docs/assistants/overview", + "GPT-4-Turbo with browsing by OpenAI", +) + +register_model_info( + ["gpt-4", "gpt-4-0314", "gpt-4-0613"], + "GPT-4", + "https://openai.com/research/gpt-4", + "GPT-4 by OpenAI", +) + +register_model_info( + ["claude-instant-1", "claude-instant-1.2"], + "Claude Instant", + "https://www.anthropic.com/index/introducing-claude", + "Claude Instant by Anthropic", +) + +register_model_info( + ["llama-2-70b-chat", "llama-2-34b-chat", "llama-2-13b-chat", "llama-2-7b-chat"], + "Llama 
2", + "https://ai.meta.com/llama/", + "Open foundation and fine-tuned chat models by Meta", +) + +register_model_info( + ["olmo-7b-instruct"], + "OLMo-7B", + "https://huggingface.co/allenai/OLMo-7B-Instruct", + "OLMo by Allen AI", +) + +register_model_info( + [ + "vicuna-33b", + "vicuna-33b-v1.3", + "vicuna-13b", + "vicuna-13b-v1.5", + "vicuna-7b", + "vicuna-7b-v1.5", + ], + "Vicuna", + "https://lmsys.org/blog/2023-03-30-vicuna/", + "A chat assistant fine-tuned on user-shared conversations by LMSYS", +) + +register_model_info( + [ + "codellama-70b-instruct", + "codellama-34b-instruct", + "codellama-13b-instruct", + "codellama-7b-instruct", + ], + "Code Llama", + "https://ai.meta.com/blog/code-llama-large-language-model-coding/", + "Open foundation models for code by Meta", +) + +register_model_info( + ["openchat-3.5-0106", "openchat-3.5"], + "OpenChat 3.5", + "https://github.com/imoneoi/openchat", + "An open model fine-tuned on Mistral-7B using C-RLFT", +) + +register_model_info( + ["deepseek-llm-67b-chat"], + "DeepSeek LLM", + "https://huggingface.co/deepseek-ai/deepseek-llm-67b-chat", + "An advanced language model by DeepSeek", +) + +register_model_info( + ["stripedhyena-nous-7b"], + "StripedHyena-Nous", + "https://huggingface.co/togethercomputer/StripedHyena-Nous-7B", + "A chat model developed by Together Research and Nous Research.", +) + +register_model_info( + ["nous-hermes-2-mixtral-8x7b-dpo"], + "Nous-Hermes-2-Mixtral-8x7B-DPO", + "https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO", + "Nous Hermes finetuned from Mixtral 8x7B", +) + + +register_model_info( + ["llama2-70b-steerlm-chat"], + "Llama2-70B-SteerLM-Chat", + "https://huggingface.co/nvidia/Llama2-70B-SteerLM-Chat", + "A Llama fine-tuned with SteerLM method by NVIDIA", +) + +register_model_info( + ["pplx-70b-online", "pplx-7b-online"], + "pplx-online-llms", + "https://blog.perplexity.ai/blog/introducing-pplx-online-llms", + "Online LLM API by Perplexity AI", +) + +register_model_info( + ["openhermes-2.5-mistral-7b"], + "OpenHermes-2.5-Mistral-7B", + "https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B", + "A mistral-based model fine-tuned on 1M GPT-4 outputs", +) + +register_model_info( + ["tulu-2-dpo-70b"], + "Tulu 2", + "https://huggingface.co/allenai/tulu-2-dpo-70b", + "An instruction and RLHF model by UW/AllenAI", +) + +register_model_info( + ["chatglm3-6b", "chatglm2-6b", "chatglm-6b"], + "ChatGLM", + "https://chatglm.cn/blog", + "An open bilingual dialogue language model by Tsinghua University", +) + +register_model_info( + ["tenyxchat-7b-v1"], + "TenyxChat-7B", + "https://huggingface.co/tenyx/TenyxChat-7B-v1", + "An open model DPO trained on top of OpenChat-3.5 using Tenyx fine-tuning", +) + +register_model_info( + ["zephyr-7b-beta", "zephyr-7b-alpha"], + "Zephyr", + "https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha", + "A chatbot fine-tuned from Mistral by Hugging Face", +) + +register_model_info( + ["notus-7b-v1"], + "Notus", + "https://huggingface.co/argilla/notus-7b-v1", + "A chatbot fine-tuned from Zephyr SFT by Argilla", +) + +register_model_info( + ["catppt"], + "CatPPT", + "https://huggingface.co/rishiraj/CatPPT", + "A chatbot fine-tuned from a SLERP merged model by Rishiraj Acharya", +) + +register_model_info( + ["TinyLlama"], + "TinyLlama", + "https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "The TinyLlama project is an open endeavor to pretrain a 1.1B Llama model on 3 trillion tokens.", +) + +register_model_info( + ["wizardlm-70b", "wizardlm-30b", "wizardlm-13b"], + 
"WizardLM", + "https://github.com/nlpxucan/WizardLM", + "An instruction-following LLM using evol-instruct by Microsoft", +) + +register_model_info( + ["wizardcoder-15b-v1.0"], + "WizardLM", + "https://github.com/nlpxucan/WizardLM/tree/main/WizardCoder", + "Empowering Code Large Language Models with Evol-Instruct", +) + +register_model_info( + ["mpt-7b-chat", "mpt-30b-chat"], + "MPT-Chat", + "https://www.mosaicml.com/blog/mpt-30b", + "A chatbot fine-tuned from MPT by MosaicML", +) + +register_model_info( + ["guanaco-33b", "guanaco-65b"], + "Guanaco", + "https://github.com/artidoro/qlora", + "A model fine-tuned with QLoRA by UW", +) + +register_model_info( + ["gpt4all-13b-snoozy"], + "GPT4All-Snoozy", + "https://github.com/nomic-ai/gpt4all", + "A finetuned LLaMA model on assistant style data by Nomic AI", +) + +register_model_info( + ["koala-13b"], + "Koala", + "https://bair.berkeley.edu/blog/2023/04/03/koala", + "A dialogue model for academic research by BAIR", +) + +register_model_info( + ["RWKV-4-Raven-14B"], + "RWKV-4-Raven", + "https://huggingface.co/BlinkDL/rwkv-4-raven", + "An RNN with transformer-level LLM performance", +) + +register_model_info( + ["alpaca-13b"], + "Alpaca", + "https://crfm.stanford.edu/2023/03/13/alpaca.html", + "A model fine-tuned from LLaMA on instruction-following demonstrations by Stanford", +) + +register_model_info( + ["oasst-pythia-12b"], + "OpenAssistant (oasst)", + "https://open-assistant.io", + "An Open Assistant for everyone by LAION", +) + +register_model_info( + ["oasst-sft-7-llama-30b"], + "OpenAssistant (oasst)", + "https://open-assistant.io", + "An Open Assistant for everyone by LAION", +) + +register_model_info( + ["palm-2"], + "PaLM 2 Chat", + "https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023", + "PaLM 2 for Chat (chat-bison@001) by Google", +) + +register_model_info( + ["llama-7b", "llama-13b"], + "LLaMA", + "https://arxiv.org/abs/2302.13971", + "Open and efficient foundation language models by Meta", +) + +register_model_info( + ["open-llama-7b-v2-open-instruct", "open-llama-7b-open-instruct"], + "Open LLaMa (Open Instruct)", + "https://medium.com/vmware-data-ml-blog/starter-llm-for-the-enterprise-instruction-tuning-openllama-7b-d05fc3bbaccc", + "Open LLaMa fine-tuned on instruction-following data by VMware", +) + +register_model_info( + ["dolly-v2-12b"], + "Dolly", + "https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm", + "An instruction-tuned open large language model by Databricks", +) + +register_model_info( + ["stablelm-tuned-alpha-7b"], + "StableLM", + "https://github.com/stability-AI/stableLM", + "Stability AI language models", +) + +register_model_info( + ["codet5p-6b"], + "CodeT5p-6b", + "https://huggingface.co/Salesforce/codet5p-6b", + "Code completion model released by Salesforce", +) + +register_model_info( + ["fastchat-t5-3b", "fastchat-t5-3b-v1.0"], + "FastChat-T5", + "https://huggingface.co/lmsys/fastchat-t5-3b-v1.0", + "A chat assistant fine-tuned from FLAN-T5 by LMSYS", +) + +register_model_info( + ["phoenix-inst-chat-7b"], + "Phoenix-7B", + "https://huggingface.co/FreedomIntelligence/phoenix-inst-chat-7b", + "A multilingual chat assistant fine-tuned from Bloomz to democratize ChatGPT across languages by CUHK(SZ)", +) + +register_model_info( + ["realm-7b-v1"], + "ReaLM", + "https://github.com/FreedomIntelligence/ReaLM", + "A chatbot fine-tuned from LLaMA2 with data generated via iterative calls to UserGPT and ChatGPT by CUHK(SZ) and SRIBD.", +) + 
+register_model_info( + ["billa-7b-sft"], + "BiLLa-7B-SFT", + "https://huggingface.co/Neutralzz/BiLLa-7B-SFT", + "An instruction-tuned bilingual LLaMA with enhanced reasoning ability by an independent researcher", +) + +register_model_info( + ["h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2"], + "h2oGPT-GM-7b", + "https://huggingface.co/h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2", + "An instruction-tuned OpenLLaMA with enhanced conversational ability by H2O.ai", +) + +register_model_info( + ["baize-v2-7b", "baize-v2-13b"], + "Baize v2", + "https://github.com/project-baize/baize-chatbot#v2", + "A chatbot fine-tuned from LLaMA with ChatGPT self-chat data and Self-Disillation with Feedback (SDF) by UCSD and SYSU.", +) + +register_model_info( + [ + "airoboros-l2-7b-2.1", + "airoboros-l2-13b-2.1", + "airoboros-c34b-2.1", + "airoboros-l2-70b-2.1", + ], + "airoboros", + "https://huggingface.co/jondurbin/airoboros-l2-70b-2.1", + "An instruction-tuned LlaMa model tuned with 100% synthetic instruction-response pairs from GPT4", +) + +register_model_info( + [ + "spicyboros-7b-2.2", + "spicyboros-13b-2.2", + "spicyboros-70b-2.2", + ], + "spicyboros", + "https://huggingface.co/jondurbin/spicyboros-70b-2.2", + "De-aligned versions of the airoboros models", +) + +register_model_info( + ["Robin-7b-v2", "Robin-13b-v2", "Robin-33b-v2"], + "Robin-v2", + "https://huggingface.co/OptimalScale/robin-7b-v2-delta", + "A chatbot fine-tuned from LLaMA-7b, achieving competitive performance on chitchat, commonsense reasoning and instruction-following tasks, by OptimalScale, HKUST.", +) + +register_model_info( + ["manticore-13b-chat"], + "Manticore 13B Chat", + "https://huggingface.co/openaccess-ai-collective/manticore-13b-chat-pyg", + "A chatbot fine-tuned from LlaMa across several CoT and chat datasets.", +) + +register_model_info( + ["redpajama-incite-7b-chat"], + "RedPajama-INCITE-7B-Chat", + "https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Chat", + "A chatbot fine-tuned from RedPajama-INCITE-7B-Base by Together", +) + +register_model_info( + [ + "falcon-7b", + "falcon-7b-instruct", + "falcon-40b", + "falcon-40b-instruct", + "falcon-180b", + "falcon-180b-chat", + ], + "Falcon", + "https://huggingface.co/tiiuae/falcon-180B", + "TII's flagship series of large language models", +) + +register_model_info( + ["tigerbot-7b-sft"], + "Tigerbot", + "https://huggingface.co/TigerResearch/tigerbot-7b-sft", + "A large-scale language model (LLM) with multiple languages and tasks.", +) + +register_model_info( + ["internlm-chat-7b", "internlm-chat-7b-8k"], + "InternLM", + "https://huggingface.co/internlm/internlm-chat-7b", + "A multi-language large-scale language model (LLM), developed by SHLAB.", +) + +register_model_info( + ["Qwen-7B-Chat"], + "Qwen", + "https://huggingface.co/Qwen/Qwen-7B-Chat", + "A multi-language large-scale language model (LLM), developed by Damo Academy.", +) + +register_model_info( + ["smaug-2-72b"], + "Smaug-2-72B", + "https://huggingface.co/abacusai/Smaug-2-72B", + "An open model trained by Abacus.AI.", +) + +register_model_info( + ["Llama2-Chinese-13b-Chat", "LLama2-Chinese-13B"], + "Llama2-Chinese", + "https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat", + "A multi-language large-scale language model (LLM), developed by FlagAlpha.", +) + +register_model_info( + ["Meta-Llama-3-8B-Instruct", "Meta-Llama-3-70B-Instruct"], + "llama-3", + "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct", + "Meta developed and released the Meta Llama 3 family of large 
language models (LLMs), a collection of pretrained and instruction tuned generative text models in 8 and 70B sizes.", +) + +register_model_info( + ["Chinese-Alpaca-2-7B", "Chinese-Alpaca-2-13B"], + "Chinese-Alpaca", + "https://huggingface.co/hfl/chinese-alpaca-2-13b", + "New extended Chinese vocabulary beyond Llama-2, open-sourcing the Chinese LLaMA-2 and Alpaca-2 LLMs.", +) + +register_model_info( + ["Vigogne-2-7B-Instruct", "Vigogne-2-13B-Instruct"], + "Vigogne-Instruct", + "https://huggingface.co/bofenghuang/vigogne-2-7b-instruct", + "A French large language model (LLM) optimized for instruction-following, developed by Bofeng Huang", +) + +register_model_info( + ["Vigogne-2-7B-Chat", "Vigogne-2-13B-Chat"], + "Vigogne-Chat", + "https://huggingface.co/bofenghuang/vigogne-2-7b-chat", + "A French large language model (LLM) optimized for instruction-following and multi-turn dialogues, developed by Bofeng Huang", +) + +register_model_info( + ["stable-vicuna-13B-HF"], + "stable-vicuna", + "https://huggingface.co/TheBloke/stable-vicuna-13B-HF", + "A Vicuna model fine-tuned using RLHF via PPO on various conversational and instructional datasets.", +) + +register_model_info( + ["deluxe-chat-v1", "deluxe-chat-v1.1", "deluxe-chat-v1.2", "deluxe-chat-v1.3"], + "DeluxeChat", + "", + "Deluxe Chat", +) + +register_model_info( + [ + "Xwin-LM-7B-V0.1", + "Xwin-LM-13B-V0.1", + "Xwin-LM-70B-V0.1", + "Xwin-LM-7B-V0.2", + "Xwin-LM-13B-V0.2", + ], + "Xwin-LM", + "https://github.com/Xwin-LM/Xwin-LM", + "Chat models developed by Xwin-LM team", +) + +register_model_info( + ["lemur-70b-chat"], + "Lemur-Chat", + "https://huggingface.co/OpenLemur/lemur-70b-chat-v1", + "An openly accessible language model optimized for both natural language and coding capabilities ", +) + +register_model_info( + ["Mistral-7B-OpenOrca"], + "Open-Orca", + "https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca", + "A fine-tune of [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) using [OpenOrca dataset](https://huggingface.co/datasets/Open-Orca/OpenOrca)", +) + +register_model_info( + ["dolphin-2.2.1-mistral-7b"], + "dolphin-mistral", + "https://huggingface.co/ehartford/dolphin-2.2.1-mistral-7b", + "An uncensored fine-tuned Mistral 7B", +) + +register_model_info( + [ + "AquilaChat-7B", + "AquilaChat2-7B", + "AquilaChat2-34B", + ], + "Aquila-Chat", + "https://huggingface.co/BAAI/AquilaChat2-34B", + "Chat models developed by BAAI team", +) + +register_model_info( + ["xDAN-L1-Chat-RL-v1"], + "xDAN-L1-Chat", + "https://huggingface.co/xDAN-AI/xDAN-L1-Chat-RL-v1", + "A large language chat model created by xDAN-AI.", +) + +register_model_info( + ["MetaMath-70B-V1.0", "MetaMath-7B-V1.0"], + "MetaMath", + "https://huggingface.co/meta-math", + "A finetune of Llama2 on [MetaMathQA](https://huggingface.co/datasets/meta-math/MetaMathQA) that specializes in mathematical reasoning.", +) + +register_model_info( + ["Yuan2-2B-hf", "Yuan2-51B-hf", "Yuan2-102B-hf"], + "IEIYuan", + "https://huggingface.co/IEITYuan", + "A Basemodel developed by IEI.", +) + +register_model_info( + [ + "llava-v1.6-34b", + "llava-v1.6-vicuna-13b", + "llava-v1.6-vicuna-7b", + "llava-v1.6-mistral-7b", + "llava-v1.5-13b", + "llava-v1.5-7b", + ], + "LLaVA", + "https://github.com/haotian-liu/LLaVA", + "an open large language and vision assistant", +) + +register_model_info( + ["gemma-7b-it", "gemma-2b-it"], + "Gemma", + "https://blog.google/technology/developers/gemma-open-models/", + "Gemma by Google", +) + +register_model_info( + [ + 
"cllm/consistency-llm-7b-codesearchnet", + "cllm/consistency-llm-7b-gsm8k", + "cllm/consistency-llm-7b-sharegpt48k", + "cllm/consistency-llm-7b-spider", + ], + "consistency-llm", + "https://huggingface.co/cllm", + "consistency-llm is a new generation of parallel decoder LLMs with fast generation speed.", +) diff --git a/fastchat/model/model_xfastertransformer.py b/fastchat/model/model_xfastertransformer.py new file mode 100644 index 0000000000000000000000000000000000000000..54890b1ca4977f4243cca46cb7c78114a3b2e5d6 --- /dev/null +++ b/fastchat/model/model_xfastertransformer.py @@ -0,0 +1,81 @@ +import gc +from threading import Thread + +import torch +from transformers import TextIteratorStreamer + + +@torch.inference_mode() +def generate_stream_xft( + model, + tokenizer, + params, + device, + context_len=8192, + stream_interval=2, + judge_sent_end=False, +): + prompt = params["prompt"] + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + + # unused now, and placehold for future. + # temperature = float(params.get("temperature", 1.0)) + # top_p = float(params.get("top_p", 1.0)) + + max_new_tokens = int(params.get("max_new_tokens", 4096)) + echo = params.get("echo", True) + + inputs = tokenizer( + prompt, return_tensors="pt", padding=model.config.padding + ).input_ids + input_echo_len = len(inputs[0]) + max_len = max_new_tokens + input_echo_len + + decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config) + generation_kwargs = { + "input_ids": inputs, + "streamer": streamer, + "max_length": max_len, + "num_beams": model.config.beam_width, + "length_penalty": repetition_penalty, + "num_return_sequences": model.config.num_return_sequences, + "early_stopping": model.config.early_stopping, + "eos_token_id": model.config.eos_token_id, + "pad_token_id": model.config.pad_token_id, + } + + thread = Thread(target=model.model.generate, kwargs=generation_kwargs) + thread.start() + if echo: + # means keep the prompt + output = prompt + else: + output = "" + i = 0 + for i, new_text in enumerate(streamer): + output += new_text + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + output = output.strip() + if i == max_new_tokens - 1: + finish_reason = "length" + else: + finish_reason = "stop" + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + gc.collect() diff --git a/fastchat/model/model_yuan2.py b/fastchat/model/model_yuan2.py new file mode 100644 index 0000000000000000000000000000000000000000..25b3e13f847cb38f22bba2cf277b55cef6c10726 --- /dev/null +++ b/fastchat/model/model_yuan2.py @@ -0,0 +1,139 @@ +import gc +from threading import Thread +from typing import Iterable + +import torch +import transformers +from transformers import TextIteratorStreamer, GenerationConfig + +from fastchat.utils import is_partial_stop + + +@torch.inference_mode() +def generate_stream_yuan2( + model, + tokenizer, + params, + device, + context_len=2048, + stream_interval=2, + judge_sent_end=False, +): + prompt = params["prompt"] + len_prompt = len(prompt) + temperature = float(params.get("temperature", 1)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 0)) + top_k = int(params.get("top_k", 1)) # -1 
means disable + max_new_tokens = int(params.get("max_new_tokens", 512)) + stop_str = params.get("stop", "") + echo = bool(params.get("echo", True)) + stop_token_ids = params.get("stop_token_ids", None) or [] + stop_token_ids.append(tokenizer("")["input_ids"][0]) + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + + max_src_len = context_len - max_new_tokens - 8 + + input_ids = input_ids[-max_src_len:] # truncate from the left + attention_mask = attention_mask[-max_src_len:] # truncate from the left + input_echo_len = len(input_ids) + + decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config) + + generation_config = GenerationConfig( + max_new_tokens=max_new_tokens, + do_sample=temperature >= 1.2, + temperature=temperature, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=10, + top_p=top_p, + top_k=top_k, + ) + + generation_kwargs = dict( + inputs=input_ids, + attention_mask=attention_mask, + streamer=streamer, + generation_config=generation_config, + ) + + thread = Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + + if echo: + # means keep the prompt + output = prompt + else: + output = "" + + for i, new_text in enumerate(streamer): + output += new_text + if i % stream_interval == 0: + if echo: + rfind_start = len_prompt + else: + rfind_start = 0 + + partially_stopped = False + if stop_str: + if isinstance(stop_str, str): + pos = output.rfind(stop_str, rfind_start) + if pos != -1: + output = output[:pos] + else: + partially_stopped = is_partial_stop(output, stop_str) + elif isinstance(stop_str, Iterable): + for each_stop in stop_str: + pos = output.rfind(each_stop, rfind_start) + if pos != -1: + output = output[:pos] + break + else: + partially_stopped = is_partial_stop(output, each_stop) + if partially_stopped: + break + else: + raise ValueError("Invalid stop field type.") + + # prevent yielding partial stop sequence + if not partially_stopped: + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + output = output.strip() + + # finish stream event, which contains finish reason + if i == max_new_tokens - 1: + finish_reason = "length" + elif partially_stopped: + finish_reason = None + else: + finish_reason = "stop" + + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + + # clean + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() diff --git a/fastchat/model/monkey_patch_non_inplace.py b/fastchat/model/monkey_patch_non_inplace.py new file mode 100644 index 0000000000000000000000000000000000000000..413dd3b30500c788abb19e5742447237ba2b1738 --- /dev/null +++ b/fastchat/model/monkey_patch_non_inplace.py @@ -0,0 +1,119 @@ +""" +Monkey patch the llama implementation in the huggingface/transformers library. +Avoid bugs in mps backend by not using in-place operations. 
+""" +import math +from typing import List, Optional, Tuple + +import torch +from torch import nn +import transformers + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2].clone() + x2 = x[..., x.shape[-1] // 2 :].clone() + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] + gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) + cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( + self.head_dim + ) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) + ) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, 
past_key_value + + +def replace_llama_attn_with_non_inplace_operations(): + """Avoid bugs in mps backend by not using in-place operations.""" + transformers.models.llama.modeling_llama.LlamaAttention.forward = forward diff --git a/fastchat/model/rwkv_model.py b/fastchat/model/rwkv_model.py new file mode 100644 index 0000000000000000000000000000000000000000..bdbc14584bfd1ec90e8478b4e55f07e8ec89a967 --- /dev/null +++ b/fastchat/model/rwkv_model.py @@ -0,0 +1,76 @@ +import os +from types import SimpleNamespace +import warnings + +import torch + +os.environ["RWKV_JIT_ON"] = "1" +os.environ["RWKV_CUDA_ON"] = "1" + +from rwkv.model import RWKV +from rwkv.utils import PIPELINE, PIPELINE_ARGS + + +class RwkvModel: + def __init__(self, model_path): + warnings.warn( + "Experimental support. Please use ChatRWKV if you want to chat with RWKV" + ) + self.config = SimpleNamespace(is_encoder_decoder=False) + self.model = RWKV(model=model_path, strategy="cuda fp16") + # two GPUs + # self.model = RWKV(model=model_path, strategy="cuda:0 fp16 *20 -> cuda:1 fp16") + + self.tokenizer = None + self.model_path = model_path + + def to(self, target): + assert target == "cuda" + + def __call__(self, input_ids, use_cache, past_key_values=None): + assert use_cache == True + input_ids = input_ids[0].detach().cpu().numpy() + # print(input_ids) + logits, state = self.model.forward(input_ids, past_key_values) + # print(logits) + logits = logits.unsqueeze(0).unsqueeze(0) + out = SimpleNamespace(logits=logits, past_key_values=state) + return out + + def generate( + self, input_ids, do_sample, temperature, max_new_tokens, repetition_penalty=1.0 + ): + # This function is used by fastchat.llm_judge. + # Because RWKV does not support huggingface generation API, + # we reuse fastchat.serve.inference.generate_stream as a workaround. + from transformers import AutoTokenizer + + from fastchat.serve.inference import generate_stream + from fastchat.conversation import get_conv_template + + if self.tokenizer is None: + self.tokenizer = AutoTokenizer.from_pretrained( + "EleutherAI/pythia-160m", use_fast=True + ) + prompt = self.tokenizer.decode(input_ids[0].tolist()) + conv = get_conv_template("rwkv") + + gen_params = { + "model": self.model_path, + "prompt": prompt, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + "max_new_tokens": max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, + "echo": False, + } + res_iter = generate_stream(self, self.tokenizer, gen_params, "cuda") + + for res in res_iter: + pass + + output = res["text"] + output_ids = self.tokenizer.encode(output) + + return [input_ids[0].tolist() + output_ids] diff --git a/fastchat/model/upload_hub.py b/fastchat/model/upload_hub.py new file mode 100644 index 0000000000000000000000000000000000000000..b1519652e6d90479d60054008d8d7e371b16356e --- /dev/null +++ b/fastchat/model/upload_hub.py @@ -0,0 +1,45 @@ +""" +Upload weights to huggingface. 
+
+Usage:
+python3 -m fastchat.model.upload_hub --model-path ~/model_weights/vicuna-13b --hub-repo-id lmsys/vicuna-13b-v1.3
+"""
+import argparse
+import tempfile
+
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+def upload_hub(model_path, hub_repo_id, component, private):
+    if component == "all":
+        components = ["model", "tokenizer"]
+    else:
+        components = [component]
+
+    kwargs = {"push_to_hub": True, "repo_id": hub_repo_id, "private": private}
+
+    if "model" in components:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
+        )
+        with tempfile.TemporaryDirectory() as tmp_path:
+            model.save_pretrained(tmp_path, **kwargs)
+
+    if "tokenizer" in components:
+        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+        with tempfile.TemporaryDirectory() as tmp_path:
+            tokenizer.save_pretrained(tmp_path, **kwargs)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-path", type=str, required=True)
+    parser.add_argument("--hub-repo-id", type=str, required=True)
+    parser.add_argument(
+        "--component", type=str, choices=["all", "model", "tokenizer"], default="all"
+    )
+    parser.add_argument("--private", action="store_true")
+    args = parser.parse_args()
+
+    upload_hub(args.model_path, args.hub_repo_id, args.component, args.private)
diff --git a/fastchat/modules/__init__.py b/fastchat/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fastchat/modules/awq.py b/fastchat/modules/awq.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f27be85c09e2394bd821cc1ce236f46c429d4bc
--- /dev/null
+++ b/fastchat/modules/awq.py
@@ -0,0 +1,85 @@
+from dataclasses import dataclass, field
+from pathlib import Path
+import sys
+
+import torch
+from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, modeling_utils
+
+
+@dataclass
+class AWQConfig:
+    ckpt: str = field(
+        default=None,
+        metadata={
+            "help": "Load quantized model. The path to the local AWQ checkpoint."
+        },
+    )
+    wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"})
+    groupsize: int = field(
+        default=-1,
+        metadata={"help": "Groupsize to use for quantization; default uses full row."},
+    )
+
+
+def load_awq_quantized(model_name, awq_config: AWQConfig, device):
+    print("Loading AWQ quantized model...")
+
+    try:
+        from tinychat.utils import load_quant
+        from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
+    except ImportError as e:
+        print(f"Error: Failed to import tinychat. 
{e}") + print("Please double check if you have successfully installed AWQ") + print("See https://github.com/lm-sys/FastChat/blob/main/docs/awq.md") + sys.exit(-1) + + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + model_name, use_fast=False, trust_remote_code=True + ) + + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.kaiming_normal_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + modeling_utils._init_weights = False + + torch.set_default_dtype(torch.half) + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + + if any(name in find_awq_ckpt(awq_config) for name in ["llama", "vicuna"]): + model = load_quant.load_awq_llama_fast( + model, + find_awq_ckpt(awq_config), + awq_config.wbits, + awq_config.groupsize, + device, + ) + make_quant_attn(model, device) + make_quant_norm(model) + make_fused_mlp(model) + else: + model = load_quant.load_awq_model( + model, + find_awq_ckpt(awq_config), + awq_config.wbits, + awq_config.groupsize, + device, + ) + return model, tokenizer + + +def find_awq_ckpt(awq_config: AWQConfig): + if Path(awq_config.ckpt).is_file(): + return awq_config.ckpt + + for ext in ["*.pt", "*.safetensors"]: + matched_result = sorted(Path(awq_config.ckpt).glob(ext)) + if len(matched_result) > 0: + return str(matched_result[-1]) + + print("Error: AWQ checkpoint not found") + sys.exit(1) diff --git a/fastchat/modules/exllama.py b/fastchat/modules/exllama.py new file mode 100644 index 0000000000000000000000000000000000000000..5e5fc81b3453a25905896cba31f9ce9dd0f0690e --- /dev/null +++ b/fastchat/modules/exllama.py @@ -0,0 +1,50 @@ +from dataclasses import dataclass, field +import sys + + +@dataclass +class ExllamaConfig: + max_seq_len: int + gpu_split: str = None + cache_8bit: bool = False + + +class ExllamaModel: + def __init__(self, exllama_model, exllama_cache): + self.model = exllama_model + self.cache = exllama_cache + self.config = self.model.config + + +def load_exllama_model(model_path, exllama_config: ExllamaConfig): + try: + from exllamav2 import ( + ExLlamaV2Config, + ExLlamaV2Tokenizer, + ExLlamaV2, + ExLlamaV2Cache, + ExLlamaV2Cache_8bit, + ) + except ImportError as e: + print(f"Error: Failed to load Exllamav2. 
{e}") + sys.exit(-1) + + exllamav2_config = ExLlamaV2Config() + exllamav2_config.model_dir = model_path + exllamav2_config.prepare() + exllamav2_config.max_seq_len = exllama_config.max_seq_len + exllamav2_config.cache_8bit = exllama_config.cache_8bit + + exllama_model = ExLlamaV2(exllamav2_config) + tokenizer = ExLlamaV2Tokenizer(exllamav2_config) + + split = None + if exllama_config.gpu_split: + split = [float(alloc) for alloc in exllama_config.gpu_split.split(",")] + exllama_model.load(split) + + cache_class = ExLlamaV2Cache_8bit if exllamav2_config.cache_8bit else ExLlamaV2Cache + exllama_cache = cache_class(exllama_model) + model = ExllamaModel(exllama_model=exllama_model, exllama_cache=exllama_cache) + + return model, tokenizer diff --git a/fastchat/modules/gptq.py b/fastchat/modules/gptq.py new file mode 100644 index 0000000000000000000000000000000000000000..fe0a220c0cfb227271fbb4d1e7c4eca636b10d1c --- /dev/null +++ b/fastchat/modules/gptq.py @@ -0,0 +1,75 @@ +from dataclasses import dataclass, field +import os +from os.path import isdir, isfile +from pathlib import Path +import sys + +from transformers import AutoTokenizer + + +@dataclass +class GptqConfig: + ckpt: str = field( + default=None, + metadata={ + "help": "Load quantized model. The path to the local GPTQ checkpoint." + }, + ) + wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"}) + groupsize: int = field( + default=-1, + metadata={"help": "Groupsize to use for quantization; default uses full row."}, + ) + act_order: bool = field( + default=True, + metadata={"help": "Whether to apply the activation order GPTQ heuristic"}, + ) + + +def load_gptq_quantized(model_name, gptq_config: GptqConfig): + print("Loading GPTQ quantized model...") + + try: + script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + module_path = os.path.join(script_path, "../repositories/GPTQ-for-LLaMa") + + sys.path.insert(0, module_path) + from llama import load_quant + except ImportError as e: + print(f"Error: Failed to load GPTQ-for-LLaMa. 
{e}") + print("See https://github.com/lm-sys/FastChat/blob/main/docs/gptq.md") + sys.exit(-1) + + tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) + # only `fastest-inference-4bit` branch cares about `act_order` + if gptq_config.act_order: + model = load_quant( + model_name, + find_gptq_ckpt(gptq_config), + gptq_config.wbits, + gptq_config.groupsize, + act_order=gptq_config.act_order, + ) + else: + # other branches + model = load_quant( + model_name, + find_gptq_ckpt(gptq_config), + gptq_config.wbits, + gptq_config.groupsize, + ) + + return model, tokenizer + + +def find_gptq_ckpt(gptq_config: GptqConfig): + if Path(gptq_config.ckpt).is_file(): + return gptq_config.ckpt + + for ext in ["*.pt", "*.safetensors"]: + matched_result = sorted(Path(gptq_config.ckpt).glob(ext)) + if len(matched_result) > 0: + return str(matched_result[-1]) + + print("Error: gptq checkpoint not found") + sys.exit(1) diff --git a/fastchat/modules/xfastertransformer.py b/fastchat/modules/xfastertransformer.py new file mode 100644 index 0000000000000000000000000000000000000000..0b49bea4cd5c9afd723318daaa5c10dcb309b776 --- /dev/null +++ b/fastchat/modules/xfastertransformer.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass +import sys + + +@dataclass +class XftConfig: + max_seq_len: int = 4096 + beam_width: int = 1 + eos_token_id: int = -1 + pad_token_id: int = -1 + num_return_sequences: int = 1 + is_encoder_decoder: bool = False + padding: bool = True + early_stopping: bool = False + data_type: str = "bf16_fp16" + + +class XftModel: + def __init__(self, xft_model, xft_config): + self.model = xft_model + self.config = xft_config + + +def load_xft_model(model_path, xft_config: XftConfig): + try: + import xfastertransformer + from transformers import AutoTokenizer + except ImportError as e: + print(f"Error: Failed to load xFasterTransformer. 
{e}") + sys.exit(-1) + + if xft_config.data_type is None or xft_config.data_type == "": + data_type = "bf16_fp16" + else: + data_type = xft_config.data_type + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=False, padding_side="left", trust_remote_code=True + ) + xft_model = xfastertransformer.AutoModel.from_pretrained( + model_path, dtype=data_type + ) + model = XftModel(xft_model=xft_model, xft_config=xft_config) + if model.model.rank > 0: + while True: + model.model.generate() + return model, tokenizer diff --git a/fastchat/protocol/api_protocol.py b/fastchat/protocol/api_protocol.py new file mode 100644 index 0000000000000000000000000000000000000000..2dc99449dca6551f8eb6c51dfe86eca28ea6a6be --- /dev/null +++ b/fastchat/protocol/api_protocol.py @@ -0,0 +1,172 @@ +from typing import Literal, Optional, List, Dict, Any, Union + +import time + +import shortuuid +from pydantic import BaseModel, Field + + +class ErrorResponse(BaseModel): + object: str = "error" + message: str + code: int + + +class ModelPermission(BaseModel): + id: str = Field(default_factory=lambda: f"modelperm-{shortuuid.random()}") + object: str = "model_permission" + created: int = Field(default_factory=lambda: int(time.time())) + allow_create_engine: bool = False + allow_sampling: bool = True + allow_logprobs: bool = True + allow_search_indices: bool = True + allow_view: bool = True + allow_fine_tuning: bool = False + organization: str = "*" + group: Optional[str] = None + is_blocking: str = False + + +class ModelCard(BaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "fastchat" + root: Optional[str] = None + parent: Optional[str] = None + permission: List[ModelPermission] = [] + + +class ModelList(BaseModel): + object: str = "list" + data: List[ModelCard] = [] + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + + +class APIChatCompletionRequest(BaseModel): + model: str + messages: Union[str, List[Dict[str, str]]] + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + top_k: Optional[int] = -1 + n: Optional[int] = 1 + max_tokens: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + user: Optional[str] = None + repetition_penalty: Optional[float] = 1.0 + frequency_penalty: Optional[float] = 0.0 + presence_penalty: Optional[float] = 0.0 + + +class ChatMessage(BaseModel): + role: str + content: str + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") + object: str = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: UsageInfo + + +class DeltaMessage(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") + object: str = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] + + +class 
APITokenCheckRequestItem(BaseModel): + model: str + prompt: str + max_tokens: int + + +class APITokenCheckRequest(BaseModel): + prompts: List[APITokenCheckRequestItem] + + +class APITokenCheckResponseItem(BaseModel): + fits: bool + tokenCount: int + contextLength: int + + +class APITokenCheckResponse(BaseModel): + prompts: List[APITokenCheckResponseItem] + + +class CompletionRequest(BaseModel): + model: str + prompt: Union[str, List[Any]] + suffix: Optional[str] = None + temperature: Optional[float] = 0.7 + n: Optional[int] = 1 + max_tokens: Optional[int] = 16 + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + top_p: Optional[float] = 1.0 + top_k: Optional[int] = -1 + logprobs: Optional[int] = None + echo: Optional[bool] = False + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + user: Optional[str] = None + + +class CompletionResponseChoice(BaseModel): + index: int + text: str + logprobs: Optional[int] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(BaseModel): + index: int + text: str + logprobs: Optional[float] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] diff --git a/fastchat/protocol/openai_api_protocol.py b/fastchat/protocol/openai_api_protocol.py new file mode 100644 index 0000000000000000000000000000000000000000..bb50a5ef0aa1d5454ce6bcc566abf15df7a1df71 --- /dev/null +++ b/fastchat/protocol/openai_api_protocol.py @@ -0,0 +1,199 @@ +from typing import Literal, Optional, List, Dict, Any, Union + +import time + +import shortuuid +from pydantic import BaseModel, Field + + +class ErrorResponse(BaseModel): + object: str = "error" + message: str + code: int + + +class ModelPermission(BaseModel): + id: str = Field(default_factory=lambda: f"modelperm-{shortuuid.random()}") + object: str = "model_permission" + created: int = Field(default_factory=lambda: int(time.time())) + allow_create_engine: bool = False + allow_sampling: bool = True + allow_logprobs: bool = True + allow_search_indices: bool = True + allow_view: bool = True + allow_fine_tuning: bool = False + organization: str = "*" + group: Optional[str] = None + is_blocking: str = False + + +class ModelCard(BaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "fastchat" + root: Optional[str] = None + parent: Optional[str] = None + permission: List[ModelPermission] = [] + + +class ModelList(BaseModel): + object: str = "list" + data: List[ModelCard] = [] + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + + +class LogProbs(BaseModel): + text_offset: List[int] = Field(default_factory=list) + token_logprobs: List[Optional[float]] = Field(default_factory=list) + tokens: List[str] = Field(default_factory=list) + top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list) + + 
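+# Example request body (illustrative only; the model name is a placeholder and
+# "messages" may also be a plain string or the nested vision-content format
+# permitted by the Union type below):
+#
+#   {
+#       "model": "vicuna-7b-v1.5",
+#       "messages": [{"role": "user", "content": "Hello!"}],
+#       "temperature": 0.7,
+#       "stream": false
+#   }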
+class ChatCompletionRequest(BaseModel): + model: str + messages: Union[ + str, + List[Dict[str, str]], + List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]], + ] + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + top_k: Optional[int] = -1 + n: Optional[int] = 1 + max_tokens: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + user: Optional[str] = None + + +class ChatMessage(BaseModel): + role: str + content: str + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") + object: str = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: UsageInfo + + +class DeltaMessage(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") + object: str = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] + + +class TokenCheckRequestItem(BaseModel): + model: str + prompt: str + max_tokens: int + + +class TokenCheckRequest(BaseModel): + prompts: List[TokenCheckRequestItem] + + +class TokenCheckResponseItem(BaseModel): + fits: bool + tokenCount: int + contextLength: int + + +class TokenCheckResponse(BaseModel): + prompts: List[TokenCheckResponseItem] + + +class EmbeddingsRequest(BaseModel): + model: Optional[str] = None + engine: Optional[str] = None + input: Union[str, List[Any]] + user: Optional[str] = None + encoding_format: Optional[str] = None + + +class EmbeddingsResponse(BaseModel): + object: str = "list" + data: List[Dict[str, Any]] + model: str + usage: UsageInfo + + +class CompletionRequest(BaseModel): + model: str + prompt: Union[str, List[Any]] + suffix: Optional[str] = None + temperature: Optional[float] = 0.7 + n: Optional[int] = 1 + max_tokens: Optional[int] = 16 + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + top_p: Optional[float] = 1.0 + top_k: Optional[int] = -1 + logprobs: Optional[int] = None + echo: Optional[bool] = False + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + user: Optional[str] = None + use_beam_search: Optional[bool] = False + best_of: Optional[int] = None + + +class CompletionResponseChoice(BaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(BaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionStreamResponse(BaseModel): + 
id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] diff --git a/fastchat/serve/__init__.py b/fastchat/serve/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fastchat/serve/api_provider.py b/fastchat/serve/api_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..a2b7979bedba6e0ddbcda39138de416da7460c19 --- /dev/null +++ b/fastchat/serve/api_provider.py @@ -0,0 +1,1127 @@ +"""Call API providers.""" + +import json +import os +import random +import re +from typing import Optional +import time + +import requests + +from fastchat.utils import build_logger + + +logger = build_logger("gradio_web_server", "gradio_web_server.log") + + +def get_api_provider_stream_iter( + conv, + model_name, + model_api_dict, + temperature, + top_p, + max_new_tokens, + state, +): + if model_api_dict["api_type"] == "openai": + if model_api_dict.get("vision-arena", False): + prompt = conv.to_openai_vision_api_messages() + else: + prompt = conv.to_openai_api_messages() + stream_iter = openai_api_stream_iter( + model_api_dict["model_name"], + prompt, + temperature, + top_p, + max_new_tokens, + api_base=model_api_dict["api_base"], + api_key=model_api_dict["api_key"], + ) + elif model_api_dict["api_type"] == "openai_no_stream": + prompt = conv.to_openai_api_messages() + stream_iter = openai_api_stream_iter( + model_api_dict["model_name"], + prompt, + temperature, + top_p, + max_new_tokens, + api_base=model_api_dict["api_base"], + api_key=model_api_dict["api_key"], + stream=False, + ) + elif model_api_dict["api_type"] == "openai_assistant": + last_prompt = conv.messages[-2][1] + stream_iter = openai_assistant_api_stream_iter( + state, + last_prompt, + assistant_id=model_api_dict["assistant_id"], + api_key=model_api_dict["api_key"], + ) + elif model_api_dict["api_type"] == "anthropic": + if model_api_dict.get("vision-arena", False): + prompt = conv.to_anthropic_vision_api_messages() + else: + prompt = conv.to_openai_api_messages() + stream_iter = anthropic_api_stream_iter( + model_name, prompt, temperature, top_p, max_new_tokens + ) + elif model_api_dict["api_type"] == "anthropic_message": + if model_api_dict.get("vision-arena", False): + prompt = conv.to_anthropic_vision_api_messages() + else: + prompt = conv.to_openai_api_messages() + stream_iter = anthropic_message_api_stream_iter( + model_api_dict["model_name"], prompt, temperature, top_p, max_new_tokens + ) + elif model_api_dict["api_type"] == "anthropic_message_vertex": + if model_api_dict.get("vision-arena", False): + prompt = conv.to_anthropic_vision_api_messages() + else: + prompt = conv.to_openai_api_messages() + stream_iter = anthropic_message_api_stream_iter( + model_api_dict["model_name"], + prompt, + temperature, + top_p, + max_new_tokens, + vertex_ai=True, + ) + elif model_api_dict["api_type"] == "gemini": + prompt = conv.to_gemini_api_messages() + stream_iter = gemini_api_stream_iter( + model_api_dict["model_name"], + prompt, + temperature, + top_p, + max_new_tokens, + api_key=model_api_dict["api_key"], + ) + elif model_api_dict["api_type"] == "gemini_no_stream": + prompt = conv.to_gemini_api_messages() + stream_iter = gemini_api_stream_iter( + model_api_dict["model_name"], + prompt, + temperature, + top_p, + max_new_tokens, + api_key=model_api_dict["api_key"], + 
use_stream=False, + ) + elif model_api_dict["api_type"] == "bard": + prompt = conv.to_openai_api_messages() + stream_iter = bard_api_stream_iter( + model_api_dict["model_name"], + prompt, + temperature, + top_p, + api_key=model_api_dict["api_key"], + ) + elif model_api_dict["api_type"] == "mistral": + prompt = conv.to_openai_api_messages() + stream_iter = mistral_api_stream_iter( + model_api_dict["model_name"], + prompt, + temperature, + top_p, + max_new_tokens, + api_key=model_api_dict.get("api_key"), + ) + elif model_api_dict["api_type"] == "nvidia": + prompt = conv.to_openai_api_messages() + stream_iter = nvidia_api_stream_iter( + model_name, + prompt, + temperature, + top_p, + max_new_tokens, + model_api_dict["api_base"], + model_api_dict["api_key"], + ) + elif model_api_dict["api_type"] == "ai2": + prompt = conv.to_openai_api_messages() + stream_iter = ai2_api_stream_iter( + model_name, + model_api_dict["model_name"], + prompt, + temperature, + top_p, + max_new_tokens, + api_base=model_api_dict["api_base"], + api_key=model_api_dict["api_key"], + ) + elif model_api_dict["api_type"] == "vertex": + prompt = conv.to_vertex_api_messages() + stream_iter = vertex_api_stream_iter( + model_name, prompt, temperature, top_p, max_new_tokens + ) + elif model_api_dict["api_type"] == "yandexgpt": + # note: top_p parameter is unused by yandexgpt + + messages = [] + if conv.system_message: + messages.append({"role": "system", "text": conv.system_message}) + messages += [ + {"role": role, "text": text} + for role, text in conv.messages + if text is not None + ] + + fixed_temperature = model_api_dict.get("fixed_temperature") + if fixed_temperature is not None: + temperature = fixed_temperature + + stream_iter = yandexgpt_api_stream_iter( + model_name=model_api_dict["model_name"], + messages=messages, + temperature=temperature, + max_tokens=max_new_tokens, + api_base=model_api_dict["api_base"], + api_key=model_api_dict.get("api_key"), + folder_id=model_api_dict.get("folder_id"), + ) + elif model_api_dict["api_type"] == "cohere": + messages = conv.to_openai_api_messages() + stream_iter = cohere_api_stream_iter( + client_name=model_api_dict.get("client_name", "FastChat"), + model_id=model_api_dict["model_name"], + messages=messages, + temperature=temperature, + top_p=top_p, + max_new_tokens=max_new_tokens, + api_base=model_api_dict["api_base"], + api_key=model_api_dict["api_key"], + ) + elif model_api_dict["api_type"] == "reka": + messages = conv.to_reka_api_messages() + stream_iter = reka_api_stream_iter( + model_name=model_api_dict["model_name"], + messages=messages, + temperature=temperature, + top_p=top_p, + max_new_tokens=max_new_tokens, + api_base=model_api_dict["api_base"], + api_key=model_api_dict["api_key"], + ) + else: + raise NotImplementedError() + + return stream_iter + + +def openai_api_stream_iter( + model_name, + messages, + temperature, + top_p, + max_new_tokens, + api_base=None, + api_key=None, + stream=True, +): + import openai + + api_key = api_key or os.environ["OPENAI_API_KEY"] + + if "azure" in model_name: + client = openai.AzureOpenAI( + api_version="2023-07-01-preview", + azure_endpoint=api_base or "https://api.openai.com/v1", + api_key=api_key, + ) + else: + client = openai.OpenAI( + base_url=api_base or "https://api.openai.com/v1", + api_key=api_key, + timeout=180, + ) + + # Make requests for logging + text_messages = [] + for message in messages: + if type(message["content"]) == str: # text-only model + text_messages.append(message) + else: # vision model + filtered_content_list 
= [ + content for content in message["content"] if content["type"] == "text" + ] + text_messages.append( + {"role": message["role"], "content": filtered_content_list} + ) + + gen_params = { + "model": model_name, + "prompt": text_messages, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + if stream: + res = client.chat.completions.create( + model=model_name, + messages=messages, + temperature=temperature, + max_tokens=max_new_tokens, + stream=True, + ) + text = "" + for chunk in res: + if len(chunk.choices) > 0: + text += chunk.choices[0].delta.content or "" + data = { + "text": text, + "error_code": 0, + } + yield data + else: + res = client.chat.completions.create( + model=model_name, + messages=messages, + temperature=temperature, + max_tokens=max_new_tokens, + stream=False, + ) + text = res.choices[0].message.content + pos = 0 + while pos < len(text): + # simulate token streaming + pos += 2 + time.sleep(0.001) + data = { + "text": text[:pos], + "error_code": 0, + } + yield data + + +def upload_openai_file_to_gcs(file_id): + import openai + from google.cloud import storage + + storage_client = storage.Client() + + file = openai.files.content(file_id) + # upload file to GCS + bucket = storage_client.get_bucket("arena_user_content") + blob = bucket.blob(f"{file_id}") + blob.upload_from_string(file.read()) + blob.make_public() + return blob.public_url + + +def openai_assistant_api_stream_iter( + state, + prompt, + assistant_id, + api_key=None, +): + import openai + import base64 + + api_key = api_key or os.environ["OPENAI_API_KEY"] + client = openai.OpenAI(base_url="https://api.openai.com/v1", api_key=api_key) + + if state.oai_thread_id is None: + logger.info("==== create thread ====") + thread = client.beta.threads.create() + state.oai_thread_id = thread.id + logger.info(f"==== thread_id ====\n{state.oai_thread_id}") + thread_message = client.beta.threads.messages.with_raw_response.create( + state.oai_thread_id, + role="user", + content=prompt, + timeout=3, + ) + # logger.info(f"header {thread_message.headers}") + thread_message = thread_message.parse() + # Make requests + gen_params = { + "assistant_id": assistant_id, + "thread_id": state.oai_thread_id, + "message": prompt, + } + logger.info(f"==== request ====\n{gen_params}") + + res = requests.post( + f"https://api.openai.com/v1/threads/{state.oai_thread_id}/runs", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "OpenAI-Beta": "assistants=v1", + }, + json={"assistant_id": assistant_id, "stream": True}, + timeout=30, + stream=True, + ) + + list_of_text = [] + list_of_raw_text = [] + offset_idx = 0 + full_ret_text = "" + idx_mapping = {} + cur_offset = 0 + for line in res.iter_lines(): + if not line: + continue + data = line.decode("utf-8") + # logger.info("data:", data) + if data.endswith("[DONE]"): + break + if data.startswith("event"): + event = data.split(":")[1].strip() + if event == "thread.message.completed": + offset_idx += len(list_of_text) + continue + data = json.loads(data[6:]) + + if data.get("status") == "failed": + yield { + "text": f"**API REQUEST ERROR** Reason: {data['last_error']['message']}", + "error_code": 1, + } + return + + if data.get("status") == "completed": + logger.info(f"[debug]: {data}") + + if data["object"] != "thread.message.delta": + continue + + for delta in data["delta"]["content"]: + text_index = delta["index"] + offset_idx + if len(list_of_text) <= text_index: + 
list_of_text.append("") + list_of_raw_text.append("") + + text = list_of_text[text_index] + raw_text = list_of_raw_text[text_index] + + if delta["type"] == "text": + # text, url_citation or file_path + content = delta["text"] + if "annotations" in content and len(content["annotations"]) > 0: + annotations = content["annotations"] + + raw_text_copy = text + for anno in annotations: + if anno["type"] == "url_citation": + pattern = r"【\d+†source】" + matches = re.findall(pattern, content["value"]) + if len(matches) > 0: + for match in matches: + print(match) + if match not in idx_mapping: + idx_mapping[match] = len(idx_mapping) + 1 + citation_number = idx_mapping[match] + + start_idx = anno["start_index"] + cur_offset + end_idx = anno["end_index"] + cur_offset + url = anno["url_citation"]["url"] + + citation = f" [[{citation_number}]]({url})" + raw_text_copy = ( + raw_text_copy[:start_idx] + + citation + + raw_text_copy[end_idx:] + ) + cur_offset += len(citation) - (end_idx - start_idx) + elif anno["type"] == "file_path": + file_public_url = upload_openai_file_to_gcs( + anno["file_path"]["file_id"] + ) + raw_text_copy = raw_text_copy.replace( + anno["text"], f"{file_public_url}" + ) + text = raw_text_copy + else: + text_content = content["value"] + text += text_content + elif delta["type"] == "image_file": + image_public_url = upload_openai_file_to_gcs( + delta["image_file"]["file_id"] + ) + text += f"![image]({image_public_url})" + + list_of_text[text_index] = text + list_of_raw_text[text_index] = raw_text + + full_ret_text = "\n".join(list_of_text) + yield {"text": full_ret_text, "error_code": 0} + + +def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens): + import anthropic + + c = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"]) + + # Make requests + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + res = c.completions.create( + prompt=prompt, + stop_sequences=[anthropic.HUMAN_PROMPT], + max_tokens_to_sample=max_new_tokens, + temperature=temperature, + top_p=top_p, + model=model_name, + stream=True, + ) + text = "" + for chunk in res: + text += chunk.completion + data = { + "text": text, + "error_code": 0, + } + yield data + + +def anthropic_message_api_stream_iter( + model_name, + messages, + temperature, + top_p, + max_new_tokens, + vertex_ai=False, +): + import anthropic + + if vertex_ai: + client = anthropic.AnthropicVertex( + region=os.environ["GCP_LOCATION"], + project_id=os.environ["GCP_PROJECT_ID"], + max_retries=5, + ) + else: + client = anthropic.Anthropic( + api_key=os.environ["ANTHROPIC_API_KEY"], + max_retries=5, + ) + + text_messages = [] + for message in messages: + if type(message["content"]) == str: # text-only model + text_messages.append(message) + else: # vision model + filtered_content_list = [ + content for content in message["content"] if content["type"] == "text" + ] + text_messages.append( + {"role": message["role"], "content": filtered_content_list} + ) + + # Make requests for logging + gen_params = { + "model": model_name, + "prompt": text_messages, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + system_prompt = "" + if messages[0]["role"] == "system": + if type(messages[0]["content"]) == dict: + system_prompt = messages[0]["content"]["text"] + elif type(messages[0]["content"]) == 
str: + system_prompt = messages[0]["content"] + # remove system prompt + messages = messages[1:] + + text = "" + with client.messages.stream( + temperature=temperature, + top_p=top_p, + max_tokens=max_new_tokens, + messages=messages, + model=model_name, + system=system_prompt, + ) as stream: + for chunk in stream.text_stream: + text += chunk + data = { + "text": text, + "error_code": 0, + } + yield data + + +def gemini_api_stream_iter( + model_name, + messages, + temperature, + top_p, + max_new_tokens, + api_key=None, + use_stream=True, +): + import google.generativeai as genai # pip install google-generativeai + + if api_key is None: + api_key = os.environ["GEMINI_API_KEY"] + genai.configure(api_key=api_key) + + generation_config = { + "temperature": temperature, + "max_output_tokens": max_new_tokens, + "top_p": top_p, + } + params = { + "model": model_name, + "prompt": messages, + } + params.update(generation_config) + logger.info(f"==== request ====\n{params}") + + safety_settings = [ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, + ] + + history = [] + system_prompt = None + for message in messages[:-1]: + if message["role"] == "system": + system_prompt = message["content"] + continue + history.append({"role": message["role"], "parts": message["content"]}) + + model = genai.GenerativeModel( + model_name=model_name, + system_instruction=system_prompt, + generation_config=generation_config, + safety_settings=safety_settings, + ) + convo = model.start_chat(history=history) + + if use_stream: + response = convo.send_message(messages[-1]["content"], stream=True) + try: + text = "" + for chunk in response: + text += chunk.candidates[0].content.parts[0].text + data = { + "text": text, + "error_code": 0, + } + yield data + except Exception as e: + logger.error(f"==== error ====\n{e}") + reason = chunk.candidates + yield { + "text": f"**API REQUEST ERROR** Reason: {reason}.", + "error_code": 1, + } + else: + try: + response = convo.send_message(messages[-1]["content"], stream=False) + text = response.candidates[0].content.parts[0].text + pos = 0 + while pos < len(text): + # simulate token streaming + pos += 3 + time.sleep(0.001) + data = { + "text": text[:pos], + "error_code": 0, + } + yield data + except Exception as e: + logger.error(f"==== error ====\n{e}") + yield { + "text": f"**API REQUEST ERROR** Reason: {e}.", + "error_code": 1, + } + + +def bard_api_stream_iter(model_name, conv, temperature, top_p, api_key=None): + del top_p # not supported + del temperature # not supported + + if api_key is None: + api_key = os.environ["BARD_API_KEY"] + + # convert conv to conv_bard + conv_bard = [] + for turn in conv: + if turn["role"] == "user": + conv_bard.append({"author": "0", "content": turn["content"]}) + elif turn["role"] == "assistant": + conv_bard.append({"author": "1", "content": turn["content"]}) + else: + raise ValueError(f"Unsupported role: {turn['role']}") + + params = { + "model": model_name, + "prompt": conv_bard, + } + logger.info(f"==== request ====\n{params}") + + try: + res = requests.post( + f"https://generativelanguage.googleapis.com/v1beta2/models/{model_name}:generateMessage?key={api_key}", + json={ + "prompt": { + "messages": conv_bard, + }, + }, + timeout=60, + ) + except Exception as e: + logger.error(f"==== error ====\n{e}") 
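+        # Surface the failure to the caller as a streamed error payload
+        # (error_code 1) rather than raising.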
+ yield { + "text": f"**API REQUEST ERROR** Reason: {e}.", + "error_code": 1, + } + + if res.status_code != 200: + logger.error(f"==== error ==== ({res.status_code}): {res.text}") + yield { + "text": f"**API REQUEST ERROR** Reason: status code {res.status_code}.", + "error_code": 1, + } + + response_json = res.json() + if "candidates" not in response_json: + logger.error(f"==== error ==== response blocked: {response_json}") + reason = response_json["filters"][0]["reason"] + yield { + "text": f"**API REQUEST ERROR** Reason: {reason}.", + "error_code": 1, + } + + response = response_json["candidates"][0]["content"] + pos = 0 + while pos < len(response): + # simulate token streaming + pos += 1 + time.sleep(0.001) + data = { + "text": response[:pos], + "error_code": 0, + } + yield data + + +def ai2_api_stream_iter( + model_name, + model_id, + messages, + temperature, + top_p, + max_new_tokens, + api_key=None, + api_base=None, +): + # get keys and needed values + ai2_key = api_key or os.environ.get("AI2_API_KEY") + api_base = api_base or "https://inferd.allen.ai/api/v1/infer" + + # Make requests + gen_params = { + "model": model_name, + "prompt": messages, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + # AI2 uses vLLM, which requires that `top_p` be 1.0 for greedy sampling: + # https://github.com/vllm-project/vllm/blob/v0.1.7/vllm/sampling_params.py#L156-L157 + if temperature == 0.0 and top_p < 1.0: + raise ValueError("top_p must be 1 when temperature is 0.0") + + res = requests.post( + api_base, + stream=True, + headers={"Authorization": f"Bearer {ai2_key}"}, + json={ + "model_id": model_id, + # This input format is specific to the Tulu2 model. Other models + # may require different input formats. See the model's schema + # documentation on InferD for more information. 
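+            # The "opts" block below passes the standard sampling parameters
+            # (max_tokens, temperature, top_p) through to the endpoint.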
+ "input": { + "messages": messages, + "opts": { + "max_tokens": max_new_tokens, + "temperature": temperature, + "top_p": top_p, + "logprobs": 1, # increase for more choices + }, + }, + }, + timeout=5, + ) + + if res.status_code != 200: + logger.error(f"unexpected response ({res.status_code}): {res.text}") + raise ValueError("unexpected response from InferD", res) + + text = "" + for line in res.iter_lines(): + if line: + part = json.loads(line) + if "result" in part and "output" in part["result"]: + for t in part["result"]["output"]["text"]: + text += t + else: + logger.error(f"unexpected part: {part}") + raise ValueError("empty result in InferD response") + + data = { + "text": text, + "error_code": 0, + } + yield data + + +def mistral_api_stream_iter( + model_name, messages, temperature, top_p, max_new_tokens, api_key=None +): + from mistralai.client import MistralClient + from mistralai.models.chat_completion import ChatMessage + + if api_key is None: + api_key = os.environ["MISTRAL_API_KEY"] + + client = MistralClient(api_key=api_key, timeout=5) + + # Make requests + gen_params = { + "model": model_name, + "prompt": messages, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + new_messages = [ + ChatMessage(role=message["role"], content=message["content"]) + for message in messages + ] + + res = client.chat_stream( + model=model_name, + temperature=temperature, + messages=new_messages, + max_tokens=max_new_tokens, + top_p=top_p, + ) + + text = "" + for chunk in res: + if chunk.choices[0].delta.content is not None: + text += chunk.choices[0].delta.content + data = { + "text": text, + "error_code": 0, + } + yield data + + +def nvidia_api_stream_iter( + model_name, messages, temp, top_p, max_tokens, api_base, api_key=None +): + model_2_api = {} + api_base += model_2_api[model_name] + + api_key = api_key or os.environ["NVIDIA_API_KEY"] + headers = { + "Authorization": f"Bearer {api_key}", + "accept": "text/event-stream", + "content-type": "application/json", + } + # nvidia api does not accept 0 temperature + if temp == 0.0: + temp = 0.000001 + + payload = { + "model": model_name, + "messages": messages, + "temperature": temp, + "top_p": top_p, + "max_tokens": max_tokens, + "seed": 42, + "stream": True, + } + logger.info(f"==== request ====\n{payload}") + + # payload.pop("model") + + # try 3 times + for i in range(3): + try: + response = requests.post( + api_base, headers=headers, json=payload, stream=True, timeout=3 + ) + break + except Exception as e: + logger.error(f"==== error ====\n{e}") + if i == 2: + yield { + "text": f"**API REQUEST ERROR** Reason: API timeout. 
please try again later.", + "error_code": 1, + } + return + + text = "" + for line in response.iter_lines(): + if line: + data = line.decode("utf-8") + if data.endswith("[DONE]"): + break + data = json.loads(data[6:])["choices"][0]["delta"]["content"] + text += data + yield {"text": text, "error_code": 0} + + +def yandexgpt_api_stream_iter( + model_name, messages, temperature, max_tokens, api_base, api_key, folder_id +): + api_key = api_key or os.environ["YANDEXGPT_API_KEY"] + headers = { + "Authorization": f"Api-Key {api_key}", + "content-type": "application/json", + } + + payload = { + "modelUri": f"gpt://{folder_id}/{model_name}", + "completionOptions": { + "temperature": temperature, + "max_tokens": max_tokens, + "stream": True, + }, + "messages": messages, + } + logger.info(f"==== request ====\n{payload}") + + # https://llm.api.cloud.yandex.net/foundationModels/v1/completion + response = requests.post( + api_base, headers=headers, json=payload, stream=True, timeout=60 + ) + text = "" + for line in response.iter_lines(): + if line: + data = json.loads(line.decode("utf-8")) + data = data["result"] + top_alternative = data["alternatives"][0] + text = top_alternative["message"]["text"] + yield {"text": text, "error_code": 0} + + status = top_alternative["status"] + if status in ( + "ALTERNATIVE_STATUS_FINAL", + "ALTERNATIVE_STATUS_TRUNCATED_FINAL", + ): + break + + +def cohere_api_stream_iter( + client_name: str, + model_id: str, + messages: list, + temperature: Optional[ + float + ] = None, # The SDK or API handles None for all parameters following + top_p: Optional[float] = None, + max_new_tokens: Optional[int] = None, + api_key: Optional[str] = None, # default is env var CO_API_KEY + api_base: Optional[str] = None, +): + import cohere + + OPENAI_TO_COHERE_ROLE_MAP = { + "user": "User", + "assistant": "Chatbot", + "system": "System", + } + + client = cohere.Client( + api_key=api_key, + base_url=api_base, + client_name=client_name, + ) + + # prepare and log requests + chat_history = [ + dict( + role=OPENAI_TO_COHERE_ROLE_MAP[message["role"]], message=message["content"] + ) + for message in messages[:-1] + ] + actual_prompt = messages[-1]["content"] + + gen_params = { + "model": model_id, + "messages": messages, + "chat_history": chat_history, + "prompt": actual_prompt, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + # make request and stream response + res = client.chat_stream( + message=actual_prompt, + chat_history=chat_history, + model=model_id, + temperature=temperature, + max_tokens=max_new_tokens, + p=top_p, + ) + try: + text = "" + for streaming_item in res: + if streaming_item.event_type == "text-generation": + text += streaming_item.text + yield {"text": text, "error_code": 0} + except cohere.core.ApiError as e: + logger.error(f"==== error from cohere api: {e} ====") + yield { + "text": f"**API REQUEST ERROR** Reason: {e}", + "error_code": 1, + } + + +def vertex_api_stream_iter(model_name, messages, temperature, top_p, max_new_tokens): + import vertexai + from vertexai import generative_models + from vertexai.generative_models import ( + GenerationConfig, + GenerativeModel, + Image, + ) + + project_id = os.environ.get("GCP_PROJECT_ID", None) + location = os.environ.get("GCP_LOCATION", None) + vertexai.init(project=project_id, location=location) + + text_messages = [] + for message in messages: + if type(message) == str: + text_messages.append(message) + + gen_params = { + "model": 
model_name, + "prompt": text_messages, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + safety_settings = [ + generative_models.SafetySetting( + category=generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold=generative_models.HarmBlockThreshold.BLOCK_NONE, + ), + generative_models.SafetySetting( + category=generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold=generative_models.HarmBlockThreshold.BLOCK_NONE, + ), + generative_models.SafetySetting( + category=generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=generative_models.HarmBlockThreshold.BLOCK_NONE, + ), + generative_models.SafetySetting( + category=generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=generative_models.HarmBlockThreshold.BLOCK_NONE, + ), + ] + generator = GenerativeModel(model_name).generate_content( + messages, + stream=True, + generation_config=GenerationConfig( + top_p=top_p, max_output_tokens=max_new_tokens, temperature=temperature + ), + safety_settings=safety_settings, + ) + + ret = "" + for chunk in generator: + # NOTE(chris): This may be a vertex api error, below is HOTFIX: https://github.com/googleapis/python-aiplatform/issues/3129 + ret += chunk.candidates[0].content.parts[0]._raw_part.text + # ret += chunk.text + data = { + "text": ret, + "error_code": 0, + } + yield data + + +def reka_api_stream_iter( + model_name: str, + messages: list, + temperature: Optional[ + float + ] = None, # The SDK or API handles None for all parameters following + top_p: Optional[float] = None, + max_new_tokens: Optional[int] = None, + api_key: Optional[str] = None, # default is env var CO_API_KEY + api_base: Optional[str] = None, +): + api_key = api_key or os.environ["REKA_API_KEY"] + + use_search_engine = False + if "-online" in model_name: + model_name = model_name.replace("-online", "") + use_search_engine = True + request = { + "model_name": model_name, + "conversation_history": messages, + "temperature": temperature, + "request_output_len": max_new_tokens, + "runtime_top_p": top_p, + "stream": True, + "use_search_engine": use_search_engine, + } + + # Make requests for logging + text_messages = [] + for message in messages: + text_messages.append({"type": message["type"], "text": message["text"]}) + logged_request = dict(request) + logged_request["conversation_history"] = text_messages + + logger.info(f"==== request ====\n{logged_request}") + + response = requests.post( + api_base, + stream=True, + json=request, + headers={ + "X-Api-Key": api_key, + }, + ) + + if response.status_code != 200: + error_message = response.text + logger.error(f"==== error from reka api: {error_message} ====") + yield { + "text": f"**API REQUEST ERROR** Reason: {error_message}", + "error_code": 1, + } + return + + for line in response.iter_lines(): + line = line.decode("utf8") + if not line.startswith("data: "): + continue + gen = json.loads(line[6:]) + yield {"text": gen["text"], "error_code": 0} diff --git a/fastchat/serve/base_model_worker.py b/fastchat/serve/base_model_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..2fe322990f1e9d7a14cb45afbc16e5574604a766 --- /dev/null +++ b/fastchat/serve/base_model_worker.py @@ -0,0 +1,241 @@ +import asyncio +import threading +import time +from typing import List + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import requests + +from 
fastchat.constants import WORKER_HEART_BEAT_INTERVAL +from fastchat.conversation import Conversation +from fastchat.utils import pretty_print_semaphore, build_logger + + +worker = None +logger = None + +app = FastAPI() + + +def heart_beat_worker(obj): + while True: + time.sleep(WORKER_HEART_BEAT_INTERVAL) + obj.send_heart_beat() + + +class BaseModelWorker: + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + conv_template: str = None, + multimodal: bool = False, + ): + global logger, worker + + self.controller_addr = controller_addr + self.worker_addr = worker_addr + self.worker_id = worker_id + if model_path.endswith("/"): + model_path = model_path[:-1] + self.model_names = model_names or [model_path.split("/")[-1]] + self.limit_worker_concurrency = limit_worker_concurrency + self.conv = self.make_conv_template(conv_template, model_path) + self.conv.sep_style = int(self.conv.sep_style) + self.multimodal = multimodal + self.tokenizer = None + self.context_len = None + self.call_ct = 0 + self.semaphore = None + + self.heart_beat_thread = None + + if logger is None: + logger = build_logger("model_worker", f"model_worker_{self.worker_id}.log") + if worker is None: + worker = self + + def make_conv_template( + self, + conv_template: str = None, + model_path: str = None, + ) -> Conversation: + """ + can be overrided to costomize the conversation template for different model workers. + """ + from fastchat.conversation import get_conv_template + from fastchat.model.model_adapter import get_conversation_template + + if conv_template: + conv = get_conv_template(conv_template) + else: + conv = get_conversation_template(model_path) + return conv + + def init_heart_beat(self): + self.register_to_controller() + self.heart_beat_thread = threading.Thread( + target=heart_beat_worker, + args=(self,), + daemon=True, + ) + self.heart_beat_thread.start() + + def register_to_controller(self): + logger.info("Register to controller") + + url = self.controller_addr + "/register_worker" + data = { + "worker_name": self.worker_addr, + "check_heart_beat": True, + "worker_status": self.get_status(), + "multimodal": self.multimodal, + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + def send_heart_beat(self): + logger.info( + f"Send heart beat. Models: {self.model_names}. " + f"Semaphore: {pretty_print_semaphore(self.semaphore)}. " + f"call_ct: {self.call_ct}. " + f"worker_id: {self.worker_id}. 
" + ) + + url = self.controller_addr + "/receive_heart_beat" + + while True: + try: + ret = requests.post( + url, + json={ + "worker_name": self.worker_addr, + "queue_length": self.get_queue_length(), + }, + timeout=5, + ) + exist = ret.json()["exist"] + break + except (requests.exceptions.RequestException, KeyError) as e: + logger.error(f"heart beat error: {e}") + time.sleep(5) + + if not exist: + self.register_to_controller() + + def get_queue_length(self): + if self.semaphore is None: + return 0 + else: + sempahore_value = ( + self.semaphore._value + if self.semaphore._value is not None + else self.limit_worker_concurrency + ) + waiter_count = ( + 0 if self.semaphore._waiters is None else len(self.semaphore._waiters) + ) + return self.limit_worker_concurrency - sempahore_value + waiter_count + + def get_status(self): + return { + "model_names": self.model_names, + "speed": 1, + "queue_length": self.get_queue_length(), + } + + def count_token(self, params): + prompt = params["prompt"] + + try: + input_ids = self.tokenizer(prompt).input_ids + input_echo_len = len(input_ids) + except TypeError: + input_echo_len = self.tokenizer.num_tokens(prompt) + + ret = { + "count": input_echo_len, + "error_code": 0, + } + return ret + + def get_conv_template(self): + return {"conv": self.conv} + + def generate_stream_gate(self, params): + raise NotImplementedError + + def generate_gate(self, params): + raise NotImplementedError + + def get_embeddings(self, params): + raise NotImplementedError + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(): + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + generator = worker.generate_stream_gate(params) + background_tasks = create_background_tasks() + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + output = await asyncio.to_thread(worker.generate_gate, params) + release_worker_semaphore() + return JSONResponse(output) + + +@app.post("/worker_get_embeddings") +async def api_get_embeddings(request: Request): + params = await request.json() + await acquire_worker_semaphore() + embedding = worker.get_embeddings(params) + release_worker_semaphore() + return JSONResponse(content=embedding) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} diff --git a/fastchat/serve/call_monitor.py b/fastchat/serve/call_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..c38067576ba3ac213ca72a627c41b05aed0b69a5 --- /dev/null +++ b/fastchat/serve/call_monitor.py @@ -0,0 +1,223 @@ +import json +import os 
+import glob +import time + +from fastapi import FastAPI +import hashlib +import asyncio + +REFRESH_INTERVAL_SEC = 300 +LOG_DIR_LIST = [] +# LOG_DIR = "/home/vicuna/tmp/test_env" + + +class Monitor: + """Monitor the number of calls to each model.""" + + def __init__(self, log_dir_list: list): + self.log_dir_list = log_dir_list + self.model_call = {} + self.user_call = {} + self.model_call_limit_global = { + "gpt-4-1106-preview": 100, + "gpt-4-0125-preview": 100, + } + self.model_call_day_limit_per_user = { + "gpt-4-1106-preview": 5, + "gpt-4-0125-preview": 5, + } + + async def update_stats(self, num_file=1) -> None: + while True: + # find the latest num_file log under log_dir + json_files = [] + for log_dir in self.log_dir_list: + json_files_per_server = glob.glob(os.path.join(log_dir, "*.json")) + json_files_per_server.sort(key=os.path.getctime, reverse=True) + json_files += json_files_per_server[:num_file] + model_call = {} + user_call = {} + for json_file in json_files: + for line in open(json_file, "r", encoding="utf-8"): + obj = json.loads(line) + if obj["type"] != "chat": + continue + if obj["model"] not in model_call: + model_call[obj["model"]] = [] + model_call[obj["model"]].append( + {"tstamp": obj["tstamp"], "user_id": obj["ip"]} + ) + if obj["ip"] not in user_call: + user_call[obj["ip"]] = [] + user_call[obj["ip"]].append( + {"tstamp": obj["tstamp"], "model": obj["model"]} + ) + + self.model_call = model_call + self.model_call_stats_hour = self.get_model_call_stats(top_k=None) + self.model_call_stats_day = self.get_model_call_stats( + top_k=None, most_recent_min=24 * 60 + ) + + self.user_call = user_call + self.user_call_stats_hour = self.get_user_call_stats(top_k=None) + self.user_call_stats_day = self.get_user_call_stats( + top_k=None, most_recent_min=24 * 60 + ) + await asyncio.sleep(REFRESH_INTERVAL_SEC) + + def get_model_call_limit(self, model: str) -> int: + if model not in self.model_call_limit_global: + return -1 + return self.model_call_limit_global[model] + + def update_model_call_limit(self, model: str, limit: int) -> bool: + if model not in self.model_call_limit_global: + return False + self.model_call_limit_global[model] = limit + return True + + def is_model_limit_reached(self, model: str) -> bool: + if model not in self.model_call_limit_global: + return False + if model not in self.model_call_stats_hour: + return False + # check if the model call limit is reached + if self.model_call_stats_hour[model] >= self.model_call_limit_global[model]: + return True + return False + + def is_user_limit_reached(self, model: str, user_id: str) -> bool: + if model not in self.model_call_day_limit_per_user: + return False + if user_id not in self.user_call_stats_day: + return False + if model not in self.user_call_stats_day[user_id]["call_dict"]: + return False + # check if the user call limit is reached + if ( + self.user_call_stats_day[user_id]["call_dict"][model] + >= self.model_call_day_limit_per_user[model] + ): + return True + return False + + def get_model_call_stats( + self, target_model=None, most_recent_min: int = 60, top_k: int = 20 + ) -> dict: + model_call_stats = {} + for model, reqs in self.model_call.items(): + if target_model is not None and model != target_model: + continue + model_call = [] + for req in reqs: + if req["tstamp"] < time.time() - most_recent_min * 60: + continue + model_call.append(req["tstamp"]) + model_call_stats[model] = len(model_call) + if top_k is not None: + top_k_model = sorted( + model_call_stats, key=lambda x: model_call_stats[x], 
reverse=True + )[:top_k] + model_call_stats = {model: model_call_stats[model] for model in top_k_model} + return model_call_stats + + def get_user_call_stats( + self, target_model=None, most_recent_min: int = 60, top_k: int = 20 + ) -> dict: + user_call_stats = {} + for user_id, reqs in self.user_call.items(): + user_model_call = {"call_dict": {}} + for req in reqs: + if req["tstamp"] < time.time() - most_recent_min * 60: + continue + if target_model is not None and req["model"] != target_model: + continue + if req["model"] not in user_model_call["call_dict"]: + user_model_call["call_dict"][req["model"]] = 0 + user_model_call["call_dict"][req["model"]] += 1 + + user_model_call["total_calls"] = sum(user_model_call["call_dict"].values()) + if user_model_call["total_calls"] > 0: + user_call_stats[user_id] = user_model_call + if top_k is not None: + top_k_user = sorted( + user_call_stats, + key=lambda x: user_call_stats[x]["total_calls"], + reverse=True, + )[:top_k] + user_call_stats = { + user_id: user_call_stats[user_id] for user_id in top_k_user + } + return user_call_stats + + def get_num_users(self, most_recent_min: int = 60) -> int: + user_call_stats = self.get_user_call_stats( + most_recent_min=most_recent_min, top_k=None + ) + return len(user_call_stats) + + +monitor = Monitor(log_dir_list=LOG_DIR_LIST) +app = FastAPI() + + +@app.on_event("startup") +async def app_startup(): + asyncio.create_task(monitor.update_stats(2)) + + +@app.get("/get_model_call_limit/{model}") +async def get_model_call_limit(model: str): + return {"model_call_limit": {model: monitor.get_model_call_limit(model)}} + + +@app.get("/update_model_call_limit/{model}/{limit}") +async def update_model_call_limit(model: str, limit: int): + if not monitor.update_model_call_limit(model, limit): + return {"success": False} + return {"success": True} + + +@app.get("/is_limit_reached") +async def is_limit_reached(model: str, user_id: str): + if monitor.is_model_limit_reached(model): + return { + "is_limit_reached": True, + "reason": f"MODEL_HOURLY_LIMIT ({model}): {monitor.get_model_call_limit(model)}", + } + if monitor.is_user_limit_reached(model, user_id): + return { + "is_limit_reached": True, + "reason": f"USER_DAILY_LIMIT ({model}): {monitor.model_call_day_limit_per_user[model]}", + } + return {"is_limit_reached": False} + + +@app.get("/get_num_users_hr") +async def get_num_users(): + return {"num_users": len(monitor.user_call_stats_hour)} + + +@app.get("/get_num_users_day") +async def get_num_users_day(): + return {"num_users": len(monitor.user_call_stats_day)} + + +@app.get("/get_user_call_stats") +async def get_user_call_stats( + model: str = None, most_recent_min: int = 60, top_k: int = None +): + return { + "user_call_stats": monitor.get_user_call_stats(model, most_recent_min, top_k) + } + + +@app.get("/get_model_call_stats") +async def get_model_call_stats( + model: str = None, most_recent_min: int = 60, top_k: int = None +): + return { + "model_call_stats": monitor.get_model_call_stats(model, most_recent_min, top_k) + } diff --git a/fastchat/serve/cli.py b/fastchat/serve/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..78f7f51b1b18b7a10f3fab937f1475067d3e5ecf --- /dev/null +++ b/fastchat/serve/cli.py @@ -0,0 +1,304 @@ +""" +Chat with a model with command line interface. + +Usage: +python3 -m fastchat.serve.cli --model lmsys/vicuna-7b-v1.5 +python3 -m fastchat.serve.cli --model lmsys/fastchat-t5-3b-v1.0 + +Other commands: +- Type "!!exit" or an empty line to exit. 
+- Type "!!reset" to start a new conversation. +- Type "!!remove" to remove the last prompt. +- Type "!!regen" to regenerate the last message. +- Type "!!save " to save the conversation history to a json file. +- Type "!!load " to load a conversation history from a json file. +""" +import argparse +import os +import re +import sys + +from prompt_toolkit import PromptSession +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from prompt_toolkit.completion import WordCompleter +from prompt_toolkit.history import InMemoryHistory +from prompt_toolkit.key_binding import KeyBindings +from rich.console import Console +from rich.live import Live +from rich.markdown import Markdown +import torch + +from fastchat.model.model_adapter import add_model_args +from fastchat.modules.awq import AWQConfig +from fastchat.modules.exllama import ExllamaConfig +from fastchat.modules.xfastertransformer import XftConfig +from fastchat.modules.gptq import GptqConfig +from fastchat.serve.inference import ChatIO, chat_loop +from fastchat.utils import str_to_torch_dtype + + +class SimpleChatIO(ChatIO): + def __init__(self, multiline: bool = False): + self._multiline = multiline + + def prompt_for_input(self, role) -> str: + if not self._multiline: + return input(f"{role}: ") + + prompt_data = [] + line = input(f"{role} [ctrl-d/z on empty line to end]: ") + while True: + prompt_data.append(line.strip()) + try: + line = input() + except EOFError as e: + break + return "\n".join(prompt_data) + + def prompt_for_output(self, role: str): + print(f"{role}: ", end="", flush=True) + + def stream_output(self, output_stream): + pre = 0 + for outputs in output_stream: + output_text = outputs["text"] + output_text = output_text.strip().split(" ") + now = len(output_text) - 1 + if now > pre: + print(" ".join(output_text[pre:now]), end=" ", flush=True) + pre = now + print(" ".join(output_text[pre:]), flush=True) + return " ".join(output_text) + + def print_output(self, text: str): + print(text) + + +class RichChatIO(ChatIO): + bindings = KeyBindings() + + @bindings.add("escape", "enter") + def _(event): + event.app.current_buffer.newline() + + def __init__(self, multiline: bool = False, mouse: bool = False): + self._prompt_session = PromptSession(history=InMemoryHistory()) + self._completer = WordCompleter( + words=["!!exit", "!!reset", "!!remove", "!!regen", "!!save", "!!load"], + pattern=re.compile("$"), + ) + self._console = Console() + self._multiline = multiline + self._mouse = mouse + + def prompt_for_input(self, role) -> str: + self._console.print(f"[bold]{role}:") + # TODO(suquark): multiline input has some issues. fix it later. + prompt_input = self._prompt_session.prompt( + completer=self._completer, + multiline=False, + mouse_support=self._mouse, + auto_suggest=AutoSuggestFromHistory(), + key_bindings=self.bindings if self._multiline else None, + ) + self._console.print() + return prompt_input + + def prompt_for_output(self, role: str): + self._console.print(f"[bold]{role.replace('/', '|')}:") + + def stream_output(self, output_stream): + """Stream output from a role.""" + # TODO(suquark): the console flickers when there is a code block + # above it. We need to cut off "live" when a code block is done. 
+ + # Create a Live context for updating the console output + with Live(console=self._console, refresh_per_second=4) as live: + # Read lines from the stream + for outputs in output_stream: + if not outputs: + continue + text = outputs["text"] + # Render the accumulated text as Markdown + # NOTE: this is a workaround for the rendering "unstandard markdown" + # in rich. The chatbots output treat "\n" as a new line for + # better compatibility with real-world text. However, rendering + # in markdown would break the format. It is because standard markdown + # treat a single "\n" in normal text as a space. + # Our workaround is adding two spaces at the end of each line. + # This is not a perfect solution, as it would + # introduce trailing spaces (only) in code block, but it works well + # especially for console output, because in general the console does not + # care about trailing spaces. + lines = [] + for line in text.splitlines(): + lines.append(line) + if line.startswith("```"): + # Code block marker - do not add trailing spaces, as it would + # break the syntax highlighting + lines.append("\n") + else: + lines.append(" \n") + markdown = Markdown("".join(lines)) + # Update the Live console output + live.update(markdown) + self._console.print() + return text + + def print_output(self, text: str): + self.stream_output([{"text": text}]) + + +class ProgrammaticChatIO(ChatIO): + def prompt_for_input(self, role) -> str: + contents = "" + # `end_sequence` signals the end of a message. It is unlikely to occur in + # message content. + end_sequence = " __END_OF_A_MESSAGE_47582648__\n" + len_end = len(end_sequence) + while True: + if len(contents) >= len_end: + last_chars = contents[-len_end:] + if last_chars == end_sequence: + break + try: + char = sys.stdin.read(1) + contents = contents + char + except EOFError: + continue + contents = contents[:-len_end] + print(f"[!OP:{role}]: {contents}", flush=True) + return contents + + def prompt_for_output(self, role: str): + print(f"[!OP:{role}]: ", end="", flush=True) + + def stream_output(self, output_stream): + pre = 0 + for outputs in output_stream: + output_text = outputs["text"] + output_text = output_text.strip().split(" ") + now = len(output_text) - 1 + if now > pre: + print(" ".join(output_text[pre:now]), end=" ", flush=True) + pre = now + print(" ".join(output_text[pre:]), flush=True) + return " ".join(output_text) + + def print_output(self, text: str): + print(text) + + +def main(args): + if args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + os.environ["XPU_VISIBLE_DEVICES"] = args.gpus + if args.enable_exllama: + exllama_config = ExllamaConfig( + max_seq_len=args.exllama_max_seq_len, + gpu_split=args.exllama_gpu_split, + cache_8bit=args.exllama_cache_8bit, + ) + else: + exllama_config = None + if args.enable_xft: + xft_config = XftConfig( + max_seq_len=args.xft_max_seq_len, + data_type=args.xft_dtype, + ) + if args.device != "cpu": + print("xFasterTransformer now is only support CPUs. 
Reset device to CPU") + args.device = "cpu" + else: + xft_config = None + if args.style == "simple": + chatio = SimpleChatIO(args.multiline) + elif args.style == "rich": + chatio = RichChatIO(args.multiline, args.mouse) + elif args.style == "programmatic": + chatio = ProgrammaticChatIO() + else: + raise ValueError(f"Invalid style for console: {args.style}") + try: + chat_loop( + args.model_path, + args.device, + args.num_gpus, + args.max_gpu_memory, + str_to_torch_dtype(args.dtype), + args.load_8bit, + args.cpu_offloading, + args.conv_template, + args.conv_system_msg, + args.temperature, + args.repetition_penalty, + args.max_new_tokens, + chatio, + gptq_config=GptqConfig( + ckpt=args.gptq_ckpt or args.model_path, + wbits=args.gptq_wbits, + groupsize=args.gptq_groupsize, + act_order=args.gptq_act_order, + ), + awq_config=AWQConfig( + ckpt=args.awq_ckpt or args.model_path, + wbits=args.awq_wbits, + groupsize=args.awq_groupsize, + ), + exllama_config=exllama_config, + xft_config=xft_config, + revision=args.revision, + judge_sent_end=args.judge_sent_end, + debug=args.debug, + history=not args.no_history, + ) + except KeyboardInterrupt: + print("exit...") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + add_model_args(parser) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument( + "--conv-system-msg", type=str, default=None, help="Conversation system message." + ) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--repetition_penalty", type=float, default=1.0) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--no-history", action="store_true") + parser.add_argument( + "--style", + type=str, + default="simple", + choices=["simple", "rich", "programmatic"], + help="Display style.", + ) + parser.add_argument( + "--multiline", + action="store_true", + help="Enable multiline input. Use ESC+Enter for newline.", + ) + parser.add_argument( + "--mouse", + action="store_true", + help="[Rich Style]: Enable mouse support for cursor positioning.", + ) + parser.add_argument( + "--judge-sent-end", + action="store_true", + help="Whether enable the correction logic that interrupts the output of sentences due to EOS.", + ) + parser.add_argument( + "--debug", + action="store_true", + help="Print useful debug information (e.g., prompts)", + ) + args = parser.parse_args() + main(args) diff --git a/fastchat/serve/controller.py b/fastchat/serve/controller.py new file mode 100644 index 0000000000000000000000000000000000000000..42d928403090d501fb9bdfa608b77bc7d9e15c31 --- /dev/null +++ b/fastchat/serve/controller.py @@ -0,0 +1,389 @@ +""" +A controller manages distributed workers. +It sends worker addresses to clients. 
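+
+Usage (a typical invocation; the flags and their defaults are defined in create_controller() below):
+python3 -m fastchat.serve.controller --host localhost --port 21001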
+""" +import argparse +import asyncio +import dataclasses +from enum import Enum, auto +import json +import logging +import os +import time +from typing import List, Union +import threading + +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +import numpy as np +import requests +import uvicorn + +from fastchat.constants import ( + CONTROLLER_HEART_BEAT_EXPIRATION, + WORKER_API_TIMEOUT, + ErrorCode, + SERVER_ERROR_MSG, +) +from fastchat.utils import build_logger + + +logger = build_logger("controller", "controller.log") + + +class DispatchMethod(Enum): + LOTTERY = auto() + SHORTEST_QUEUE = auto() + + @classmethod + def from_str(cls, name): + if name == "lottery": + return cls.LOTTERY + elif name == "shortest_queue": + return cls.SHORTEST_QUEUE + else: + raise ValueError(f"Invalid dispatch method") + + +@dataclasses.dataclass +class WorkerInfo: + model_names: List[str] + speed: int + queue_length: int + check_heart_beat: bool + last_heart_beat: str + multimodal: bool + + +def heart_beat_controller(controller): + while True: + time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) + controller.remove_stale_workers_by_expiration() + + +class Controller: + def __init__(self, dispatch_method: str): + # Dict[str -> WorkerInfo] + self.worker_info = {} + self.dispatch_method = DispatchMethod.from_str(dispatch_method) + + self.heart_beat_thread = threading.Thread( + target=heart_beat_controller, args=(self,) + ) + self.heart_beat_thread.start() + + def register_worker( + self, + worker_name: str, + check_heart_beat: bool, + worker_status: dict, + multimodal: bool, + ): + if worker_name not in self.worker_info: + logger.info(f"Register a new worker: {worker_name}") + else: + logger.info(f"Register an existing worker: {worker_name}") + + if not worker_status: + worker_status = self.get_worker_status(worker_name) + if not worker_status: + return False + + self.worker_info[worker_name] = WorkerInfo( + worker_status["model_names"], + worker_status["speed"], + worker_status["queue_length"], + check_heart_beat, + time.time(), + multimodal, + ) + + logger.info(f"Register done: {worker_name}, {worker_status}") + return True + + def get_worker_status(self, worker_name: str): + try: + r = requests.post(worker_name + "/worker_get_status", timeout=5) + except requests.exceptions.RequestException as e: + logger.error(f"Get status fails: {worker_name}, {e}") + return None + + if r.status_code != 200: + logger.error(f"Get status fails: {worker_name}, {r}") + return None + + return r.json() + + def remove_worker(self, worker_name: str): + del self.worker_info[worker_name] + + def refresh_all_workers(self): + old_info = dict(self.worker_info) + self.worker_info = {} + + for w_name, w_info in old_info.items(): + if not self.register_worker( + w_name, w_info.check_heart_beat, None, w_info.multimodal + ): + logger.info(f"Remove stale worker: {w_name}") + + def list_models(self): + model_names = set() + + for w_name, w_info in self.worker_info.items(): + model_names.update(w_info.model_names) + + return list(model_names) + + def list_multimodal_models(self): + model_names = set() + + for w_name, w_info in self.worker_info.items(): + if w_info.multimodal: + model_names.update(w_info.model_names) + + return list(model_names) + + def list_language_models(self): + model_names = set() + + for w_name, w_info in self.worker_info.items(): + if not w_info.multimodal: + model_names.update(w_info.model_names) + + return list(model_names) + + def get_worker_address(self, model_name: str): + if 
self.dispatch_method == DispatchMethod.LOTTERY: + worker_names = [] + worker_speeds = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_speeds.append(w_info.speed) + worker_speeds = np.array(worker_speeds, dtype=np.float32) + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + if True: # Directly return address + pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds) + worker_name = worker_names[pt] + return worker_name + + # Check status before returning + while True: + pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds) + worker_name = worker_names[pt] + + if self.get_worker_status(worker_name): + break + else: + self.remove_worker(worker_name) + worker_speeds[pt] = 0 + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + continue + return worker_name + elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: + worker_names = [] + worker_qlen = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_qlen.append(w_info.queue_length / w_info.speed) + if len(worker_names) == 0: + return "" + min_index = np.argmin(worker_qlen) + w_name = worker_names[min_index] + self.worker_info[w_name].queue_length += 1 + logger.info( + f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}" + ) + return w_name + else: + raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") + + def receive_heart_beat(self, worker_name: str, queue_length: int): + if worker_name not in self.worker_info: + logger.info(f"Receive unknown heart beat. {worker_name}") + return False + + self.worker_info[worker_name].queue_length = queue_length + self.worker_info[worker_name].last_heart_beat = time.time() + logger.info(f"Receive heart beat. {worker_name}") + return True + + def remove_stale_workers_by_expiration(self): + expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION + to_delete = [] + for worker_name, w_info in self.worker_info.items(): + if w_info.check_heart_beat and w_info.last_heart_beat < expire: + to_delete.append(worker_name) + + for worker_name in to_delete: + self.remove_worker(worker_name) + + def handle_no_worker(self, params): + logger.info(f"no worker: {params['model']}") + ret = { + "text": SERVER_ERROR_MSG, + "error_code": ErrorCode.CONTROLLER_NO_WORKER, + } + return json.dumps(ret).encode() + b"\0" + + def handle_worker_timeout(self, worker_address): + logger.info(f"worker timeout: {worker_address}") + ret = { + "text": SERVER_ERROR_MSG, + "error_code": ErrorCode.CONTROLLER_WORKER_TIMEOUT, + } + return json.dumps(ret).encode() + b"\0" + + # Let the controller act as a worker to achieve hierarchical + # management. This can be used to connect isolated sub networks. 
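+    # Note: worker_api_get_status below aggregates model names, speed, and queue
+    # length across all registered workers, while worker_api_generate_stream picks
+    # a worker serving the requested model and proxies the streamed chunks back.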
+ def worker_api_get_status(self): + model_names = set() + speed = 0 + queue_length = 0 + + for w_name in self.worker_info: + worker_status = self.get_worker_status(w_name) + if worker_status is not None: + model_names.update(worker_status["model_names"]) + speed += worker_status["speed"] + queue_length += worker_status["queue_length"] + + model_names = sorted(list(model_names)) + return { + "model_names": model_names, + "speed": speed, + "queue_length": queue_length, + } + + def worker_api_generate_stream(self, params): + worker_addr = self.get_worker_address(params["model"]) + if not worker_addr: + yield self.handle_no_worker(params) + + try: + response = requests.post( + worker_addr + "/worker_generate_stream", + json=params, + stream=True, + timeout=WORKER_API_TIMEOUT, + ) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + yield chunk + b"\0" + except requests.exceptions.RequestException as e: + yield self.handle_worker_timeout(worker_addr) + + +app = FastAPI() + + +@app.post("/register_worker") +async def register_worker(request: Request): + data = await request.json() + controller.register_worker( + data["worker_name"], + data["check_heart_beat"], + data.get("worker_status", None), + data.get("multimodal", False), + ) + + +@app.post("/refresh_all_workers") +async def refresh_all_workers(): + models = controller.refresh_all_workers() + + +@app.post("/list_models") +async def list_models(): + models = controller.list_models() + return {"models": models} + + +@app.post("/list_multimodal_models") +async def list_multimodal_models(): + models = controller.list_multimodal_models() + return {"models": models} + + +@app.post("/list_language_models") +async def list_language_models(): + models = controller.list_language_models() + return {"models": models} + + +@app.post("/get_worker_address") +async def get_worker_address(request: Request): + data = await request.json() + addr = controller.get_worker_address(data["model"]) + return {"address": addr} + + +@app.post("/receive_heart_beat") +async def receive_heart_beat(request: Request): + data = await request.json() + exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"]) + return {"exist": exist} + + +@app.post("/worker_generate_stream") +async def worker_api_generate_stream(request: Request): + params = await request.json() + generator = controller.worker_api_generate_stream(params) + return StreamingResponse(generator) + + +@app.post("/worker_get_status") +async def worker_api_get_status(request: Request): + return controller.worker_api_get_status() + + +@app.get("/test_connection") +async def worker_api_get_status(request: Request): + return "success" + + +def create_controller(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21001) + parser.add_argument( + "--dispatch-method", + type=str, + choices=["lottery", "shortest_queue"], + default="shortest_queue", + ) + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. 
Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + controller = Controller(args.dispatch_method) + return args, controller + + +if __name__ == "__main__": + args, controller = create_controller() + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/fastchat/serve/dashinfer_worker.py b/fastchat/serve/dashinfer_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..261c57ad5219f1cb6fbc6f415eba439e1f9fa7da --- /dev/null +++ b/fastchat/serve/dashinfer_worker.py @@ -0,0 +1,326 @@ +""" +A model worker that executes the model based on dash-infer. + +See documentations at docs/dashinfer_integration.md +""" + +import argparse +import asyncio +import copy +import json +import os +import subprocess +from typing import List + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn +from dashinfer.helper import EngineHelper, ConfigManager + +from fastchat.constants import ErrorCode, SERVER_ERROR_MSG +from fastchat.serve.base_model_worker import BaseModelWorker +from fastchat.serve.model_worker import ( + logger, + worker_id, +) +from fastchat.utils import build_logger, get_context_length, is_partial_stop + + +app = FastAPI() + + +def download_model(model_id, revision): + source = "huggingface" + if os.environ.get("FASTCHAT_USE_MODELSCOPE", "False").lower() == "true": + source = "modelscope" + + logger.info(f"Downloading model {model_id} (revision: {revision}) from {source}") + if source == "modelscope": + from modelscope import snapshot_download + + model_dir = snapshot_download(model_id, revision=revision) + elif source == "huggingface": + from huggingface_hub import snapshot_download + + model_dir = snapshot_download(repo_id=model_id) + else: + raise ValueError("Unknown source") + + logger.info(f"Save model to path {model_dir}") + + return model_dir + + +class DashInferWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + revision: str, + no_register: bool, + config: json, + conv_template: str, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template, + ) + + logger.info( + f"Loading the model {self.model_names} on worker {worker_id}, worker type: dash-infer worker..." 
+ ) + # check if model_path is existed at local path + if not os.path.exists(model_path): + model_path = download_model(model_path, revision) + engine_helper = EngineHelper(config) + engine_helper.init_tokenizer(model_path) + engine_helper.convert_model(model_path) + engine_helper.init_engine() + + self.context_len = engine_helper.engine_config["engine_max_length"] + self.tokenizer = engine_helper.tokenizer + self.engine_helper = engine_helper + + if not no_register: + self.init_heart_beat() + + async def generate_stream(self, params): + self.call_ct += 1 + + context = params.pop("prompt") + temperature = params.get("temperature") + top_k = params.get("top_k") + top_p = params.get("top_p") + repetition_penalty = params.get("repetition_penalty") + presence_penalty = params.get("presence_penalty") + max_new_tokens = params.get("max_new_tokens") + stop_token_ids = params.get("stop_token_ids") or [] + if self.tokenizer.eos_token_id is not None: + stop_token_ids.append(self.tokenizer.eos_token_id) + seed = params.get("seed") + echo = params.get("echo", True) + logprobs = params.get("logprobs") + # not supported parameters + frequency_penalty = params.get("frequency_penalty") + stop = params.get("stop") + use_beam_search = params.get("use_beam_search", False) + best_of = params.get("best_of", None) + + gen_cfg = copy.deepcopy(self.engine_helper.default_gen_cfg) or dict() + if temperature is not None: + gen_cfg["temperature"] = float(temperature) + if top_k is not None: + dashinfer_style_top_k = 0 if int(top_k) == -1 else int(top_k) + gen_cfg["top_k"] = dashinfer_style_top_k + if top_p is not None: + gen_cfg["top_p"] = float(top_p) + if repetition_penalty is not None: + gen_cfg["repetition_penalty"] = float(repetition_penalty) + if presence_penalty is not None: + gen_cfg["presence_penalty"] = float(presence_penalty) + if len(stop_token_ids) != 0: + dashinfer_style_stop_token_ids = [[id] for id in set(stop_token_ids)] + logger.info( + f"dashinfer_style_stop_token_ids = {dashinfer_style_stop_token_ids}" + ) + gen_cfg["stop_words_ids"] = dashinfer_style_stop_token_ids + if seed is not None: + gen_cfg["seed"] = int(seed) + if logprobs is not None: + gen_cfg["logprobs"] = True + gen_cfg["top_logprobs"] = int(logprobs) + if frequency_penalty is not None: + logger.warning( + "dashinfer worker does not support `frequency_penalty` parameter" + ) + if stop is not None: + logger.warning("dashinfer worker does not support `stop` parameter") + if use_beam_search == True: + logger.warning( + "dashinfer worker does not support `use_beam_search` parameter" + ) + if best_of is not None: + logger.warning("dashinfer worker does not support `best_of` parameter") + + logger.info( + f"dashinfer engine helper creates request with context: {context}, gen_cfg: {gen_cfg}" + ) + + request_list = self.engine_helper.create_request([context], gen_cfg=[gen_cfg]) + + engine_req = request_list[0] + + # check if prompt tokens exceed the max_tokens + max_tokens = ( + gen_cfg["max_length"] + if max_new_tokens is None + else engine_req.in_tokens_len + max_new_tokens + ) + if engine_req.in_tokens_len > max_tokens: + ret = { + "text": f"This model's maximum generated tokens include context are {max_tokens}, However, your context resulted in {engine_req.in_tokens_len} tokens", + "error_code": ErrorCode.CONTEXT_OVERFLOW, + } + yield json.dumps(ret).encode() + b"\0" + else: + gen_cfg["max_length"] = int(max_tokens) + logger.info( + f"dashinfer is going to process one request in stream mode: {engine_req}" + ) + results_generator = 
self.engine_helper.process_one_request_stream( + engine_req + ) + + try: + for generate_text in results_generator: + if echo: + output_text = context + generate_text + else: + output_text = generate_text + prompt_tokens = engine_req.in_tokens_len + completion_tokens = engine_req.out_tokens_len + ret = { + "text": output_text, + "error_code": 0, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + } + yield (json.dumps(ret) + "\0").encode() + except Exception as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + yield json.dumps(ret).encode() + b"\0" + + async def generate(self, params): + async for x in self.generate_stream(params): + pass + return json.loads(x[:-1].decode()) + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(): + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + generator = worker.generate_stream(params) + background_tasks = create_background_tasks() + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + output = await worker.generate(params) + release_worker_semaphore() + return JSONResponse(output) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--model-path", type=str, default="qwen/Qwen-7B-Chat") + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + + parser.add_argument("--limit-worker-concurrency", type=int, default=1024) + parser.add_argument("--no-register", action="store_true") + parser.add_argument( + "--revision", + type=str, + default="main", + help="Hugging Face Hub model revision identifier", + ) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." 
+    )
+    parser.add_argument(
+        "config_file",
+        metavar="config-file",
+        type=str,
+        default="config_qwen_v10_7b.json",
+        help="A model config file which dash-infer reads",
+    )
+
+    args = parser.parse_args()
+    config = ConfigManager.get_config_from_json(args.config_file)
+
+    cmd = f"pip show dashinfer | grep 'Location' | cut -d ' ' -f 2"
+    package_location = subprocess.run(
+        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, text=True
+    )
+    package_location = package_location.stdout.strip()
+    os.environ["AS_DAEMON_PATH"] = package_location + "/dashinfer/allspark/bin"
+    os.environ["AS_NUMA_NUM"] = str(len(config["device_ids"]))
+    os.environ["AS_NUMA_OFFSET"] = str(config["device_ids"][0])
+    worker = DashInferWorker(
+        args.controller_address,
+        args.worker_address,
+        worker_id,
+        args.model_path,
+        args.model_names,
+        args.limit_worker_concurrency,
+        args.revision,
+        args.no_register,
+        config,
+        args.conv_template,
+    )
+    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/fastchat/serve/example_images/distracted.jpg b/fastchat/serve/example_images/distracted.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..382c888a0305296d7307ce061d527e1c5e01aca3
Binary files /dev/null and b/fastchat/serve/example_images/distracted.jpg differ
diff --git a/fastchat/serve/example_images/fridge.jpg b/fastchat/serve/example_images/fridge.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8ed943e8be506b2a0da66bd1cddf39d2dcbdb5fb
Binary files /dev/null and b/fastchat/serve/example_images/fridge.jpg differ
diff --git a/fastchat/serve/gateway/README.md b/fastchat/serve/gateway/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b3afaf171bc38b232b68609585244c9e76489da7
--- /dev/null
+++ b/fastchat/serve/gateway/README.md
@@ -0,0 +1,57 @@
+# FastChat Nginx Gateway
+
+## Purpose of the Gateway
+
+The Nginx gateway serves the following purposes:
+
+1. Protects Gradio servers by acting as a firewall.
+2. Facilitates dynamic mounting and unmounting of Gradio servers.
+3. Provides load balancing for Gradio servers.
+4. Offers additional security features, such as a total connection limit.
+5. Reduces the attack surface by requiring only a single public port to be exposed for serving.
+
+## Deployment and Updating of the Gateway
+
+### Installing Nginx
+
+On Debian-based distributions (e.g., Ubuntu):
+
+```bash
+sudo apt update
+sudo apt install nginx
+```
+On Red Hat-based distributions (e.g., CentOS, Fedora):
+
+```bash
+sudo yum install epel-release
+sudo yum install nginx
+```
+
+### Deployment
+
+Copy `nginx.conf` to `/etc/nginx/nginx.conf` (requires sudo permission).
+
+Replace the port number 7860 in `server localhost:7860` with the port where you deploy the Gradio web server.
+
+Modify `upstream websocket` to configure the Gradio servers behind the gateway.
+
+Lastly, reload Nginx (see the Updating section below).
+
+
+### HTTPS Deployment with a Public Domain URL
+
+Make sure you have obtained the HTTPS certificate and the private key used to generate it.
+
+Fill in the paths to your certificate and private key in the `[PATH_TO_SSL_CERT]` and `[PATH_TO_PRIVATE_KEY]` fields.
+
+If you have your own domain URL for serving the chatbot, replace the chat.lmsys.org URL with your own domain URL.
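+
+As a rough sketch (the certificate paths and the `chat.example.com` domain below are placeholders, not values shipped with FastChat), the filled-in lines in the `server` block of `nginx.conf` would look like:
+
+```nginx
+server {
+    listen 443 ssl;
+    ssl_certificate /etc/ssl/certs/chat.example.com.pem;        # was [PATH_TO_SSL_CERT]
+    ssl_certificate_key /etc/ssl/private/chat.example.com.key;  # was [PATH_TO_PRIVATE_KEY]
+    server_name chat.example.com;                               # was chat.lmsys.org
+    # ... keep the remaining proxy settings from the provided nginx.conf
+}
+```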
+ +### Updating + +Every time when `/etc/nginx/nginx.conf` is modified, you need to update the Nginx service: + +```bash +sudo nginx -t # check `/etc/nginx/nginx.conf` +sudo systemctl reload nginx # restart Nginx service to load the new config +sudo systemctl status nginx # check the status of the Nginx service. It should be active (running). +``` diff --git a/fastchat/serve/gateway/nginx.conf b/fastchat/serve/gateway/nginx.conf new file mode 100644 index 0000000000000000000000000000000000000000..b88ca8c50772421fca91f33ff77ef75f4d23ad4d --- /dev/null +++ b/fastchat/serve/gateway/nginx.conf @@ -0,0 +1,97 @@ +user www-data; +worker_processes auto; +pid /run/nginx.pid; +include /etc/nginx/modules-enabled/*.conf; + +events { + worker_connections 1024; # maximum number of connections that a worker process can handle concurrently + # multi_accept on; # enabling multi_accept can help improve performance under high load, but may increase the number of simultaneous connections that a worker process can handle + +} + +http { + ## + # Basic Settings + ## + + sendfile on; # enable sendfile for performance optimization + tcp_nopush on; # enable TCP no-pushing + tcp_nodelay on; # enable TCP no-delay + keepalive_timeout 65; # sets the timeout for keep-alive connections + types_hash_max_size 2048; # maximum size of the types hash table + # server_tokens off; # disable server token (i.e., server signature) in response headers to improve security + + # server_names_hash_bucket_size 64; + # server_name_in_redirect off; + + include /etc/nginx/mime.types; # include MIME types file + default_type application/octet-stream; # default MIME type for unknown file types + + ## + # SSL Settings + ## + + ssl_protocols TLSv1.2; # specify SSL/TLS protocols to use + ssl_prefer_server_ciphers on; # prefer server ciphers over client ciphers + + ## + # Logging Settings + ## + + access_log /var/log/nginx/access.log; # path to access log file + error_log /var/log/nginx/error.log; # path to error log file + + ## + # Gzip Settings + ## + gzip on; # enable Gzip compression + + ## + # Virtual Host Configs + ## + + include /etc/nginx/conf.d/*.conf; # include all configuration files in conf.d directory + include /etc/nginx/sites-enabled/*; # include all enabled sites configuration files + + # WebSocket Proxy: https://www.nginx.com/blog/websocket-nginx/ + map $http_upgrade $connection_upgrade { + default upgrade; + '' close; + } + + upstream websocket { + ip_hash; # load balancing by IP to guarantee session persistence + server localhost:7860; # The port should be the gradio web server port + # server localhost:7861; # extra gradio server if more than one + } + + limit_conn_status 429; + limit_conn_zone $binary_remote_addr zone=perip:10m; # limit number of connections per IP + limit_conn_zone $server_name zone=perserver:10m; # limit number of connections per server + + server { + listen 443 ssl; # the listening port of our server + ssl_certificate [PATH_TO_SSL_CERT]; + ssl_certificate_key [PATH_TO_PRIVATE_KEY]; + server_name chat.lmsys.org; # replace the url with your own domain url + limit_conn perserver 1024; # connections per server + location / { + proxy_pass http://websocket; # proxy all requests to the defined upstream server + limit_conn perip 5; # connections per IP + proxy_set_header Host $host; # set the Host header for the upstream server + proxy_set_header X-Real-IP $remote_addr; # set the client IP address as the real IP for the upstream server + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; # set the 
client IP addresses in the X-Forwarded-For header + proxy_http_version 1.1; # use HTTP version 1.1 for upstream communication + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "Upgrade"; # set the Connection header to Upgrade to enable WebSocket communication + } + } + + # the following block routes all HTTP traffic to HTTPS via nginx + server { + listen 80; + server_name chat.lmsys.org; + return 301 https://chat.lmsys.org$request_uri; + } + +} diff --git a/fastchat/serve/gradio_block_arena_anony.py b/fastchat/serve/gradio_block_arena_anony.py new file mode 100644 index 0000000000000000000000000000000000000000..e164db04f95e2738cd1a74b37e2bc9b20aaa8f67 --- /dev/null +++ b/fastchat/serve/gradio_block_arena_anony.py @@ -0,0 +1,623 @@ +""" +Chatbot Arena (battle) tab. +Users chat with two anonymous models. +""" + +import json +import time + +import gradio as gr +import numpy as np + +from fastchat.constants import ( + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + SLOW_MODEL_MSG, + BLIND_MODE_INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, + SURVEY_LINK, +) +from fastchat.model.model_adapter import get_conversation_template +from fastchat.serve.gradio_block_arena_named import flash_buttons +from fastchat.serve.gradio_web_server import ( + State, + bot_response, + get_conv_log_filename, + no_change_btn, + enable_btn, + disable_btn, + invisible_btn, + enable_text, + disable_text, + acknowledgment_md, + get_ip, + get_model_description_md, +) +from fastchat.serve.remote_logger import get_remote_logger +from fastchat.utils import ( + build_logger, + moderation_filter, +) + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + +num_sides = 2 +enable_moderation = False +anony_names = ["", ""] +models = [] + + +def set_global_vars_anony(enable_moderation_): + global enable_moderation + enable_moderation = enable_moderation_ + + +def load_demo_side_by_side_anony(models_, url_params): + global models + models = models_ + + states = (None,) * num_sides + selector_updates = ( + gr.Markdown(visible=True), + gr.Markdown(visible=True), + ) + + return states + selector_updates + + +def vote_last_response(states, vote_type, model_selectors, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "models": [x for x in model_selectors], + "states": [x.dict() for x in states], + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + get_remote_logger().log(data) + + gr.Info( + "🎉 Thanks for voting! Your vote shapes the leaderboard, please vote RESPONSIBLY." + ) + if ":" not in model_selectors[0]: + for i in range(5): + names = ( + "### Model A: " + states[0].model_name, + "### Model B: " + states[1].model_name, + ) + # yield names + ("",) + (disable_btn,) * 4 + yield names + (disable_text,) + (disable_btn,) * 5 + time.sleep(0.1) + else: + names = ( + "### Model A: " + states[0].model_name, + "### Model B: " + states[1].model_name, + ) + # yield names + ("",) + (disable_btn,) * 4 + yield names + (disable_text,) + (disable_btn,) * 5 + + +def leftvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"leftvote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "leftvote", [model_selector0, model_selector1], request + ): + yield x + + +def rightvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"rightvote (anony). 
ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "rightvote", [model_selector0, model_selector1], request + ): + yield x + + +def tievote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"tievote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "tievote", [model_selector0, model_selector1], request + ): + yield x + + +def bothbad_vote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request + ): + yield x + + +def regenerate(state0, state1, request: gr.Request): + logger.info(f"regenerate (anony). ip: {get_ip(request)}") + states = [state0, state1] + if state0.regen_support and state1.regen_support: + for i in range(num_sides): + states[i].conv.update_last_message(None) + return ( + states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6 + ) + states[0].skip_next = True + states[1].skip_next = True + return states + [x.to_gradio_chatbot() for x in states] + [""] + [no_change_btn] * 6 + + +def clear_history(request: gr.Request): + logger.info(f"clear_history (anony). ip: {get_ip(request)}") + return ( + [None] * num_sides + + [None] * num_sides + + anony_names + + [enable_text] + + [invisible_btn] * 4 + + [disable_btn] * 2 + + [""] + + [enable_btn] + ) + + +def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request): + logger.info(f"share (anony). ip: {get_ip(request)}") + if state0 is not None and state1 is not None: + vote_last_response( + [state0, state1], "share", [model_selector0, model_selector1], request + ) + + +SAMPLING_WEIGHTS = {} + +# target model sampling weights will be boosted. +BATTLE_TARGETS = {} + +ANON_MODELS = [] + +SAMPLING_BOOST_MODELS = [] + +# outage models won't be sampled. 
+OUTAGE_MODELS = [] + + +def get_sample_weight(model, outage_models, sampling_weights, sampling_boost_models=[]): + if model in outage_models: + return 0 + weight = sampling_weights.get(model, 1e-10) + if model in sampling_boost_models: + weight *= 5 + return weight + + +def get_battle_pair( + models, battle_targets, outage_models, sampling_weights, sampling_boost_models +): + if len(models) == 1: + return models[0], models[0] + + model_weights = [] + for model in models: + weight = get_sample_weight( + model, outage_models, sampling_weights, sampling_boost_models + ) + model_weights.append(weight) + total_weight = np.sum(model_weights) + model_weights = model_weights / total_weight + chosen_idx = np.random.choice(len(models), p=model_weights) + chosen_model = models[chosen_idx] + # for p, w in zip(models, model_weights): + # print(p, w) + + rival_models = [] + rival_weights = [] + for model in models: + if model == chosen_model: + continue + if model in ANON_MODELS and chosen_model in ANON_MODELS: + continue + weight = get_sample_weight(model, outage_models, sampling_weights) + if ( + weight != 0 + and chosen_model in battle_targets + and model in battle_targets[chosen_model] + ): + # boost to 20% chance + weight = 0.5 * total_weight / len(battle_targets[chosen_model]) + rival_models.append(model) + rival_weights.append(weight) + # for p, w in zip(rival_models, rival_weights): + # print(p, w) + rival_weights = rival_weights / np.sum(rival_weights) + rival_idx = np.random.choice(len(rival_models), p=rival_weights) + rival_model = rival_models[rival_idx] + + swap = np.random.randint(2) + if swap == 0: + return chosen_model, rival_model + else: + return rival_model, chosen_model + + +def add_text( + state0, state1, model_selector0, model_selector1, text, request: gr.Request +): + ip = get_ip(request) + logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}") + states = [state0, state1] + model_selectors = [model_selector0, model_selector1] + + # Init states if necessary + if states[0] is None: + assert states[1] is None + + model_left, model_right = get_battle_pair( + models, + BATTLE_TARGETS, + OUTAGE_MODELS, + SAMPLING_WEIGHTS, + SAMPLING_BOOST_MODELS, + ) + states = [ + State(model_left), + State(model_right), + ] + + if len(text) <= 0: + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + ["", None] + + [ + no_change_btn, + ] + * 6 + + [""] + ) + + model_list = [states[i].model_name for i in range(num_sides)] + # turn on moderation in battle mode + all_conv_text_left = states[0].conv.get_prompt() + all_conv_text_right = states[0].conv.get_prompt() + all_conv_text = ( + all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text + ) + flagged = moderation_filter(all_conv_text, model_list, do_moderation=True) + if flagged: + logger.info(f"violate moderation (anony). ip: {ip}. text: {text}") + # overwrite the original text + text = MODERATION_MSG + + conv = states[0].conv + if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + logger.info(f"conversation turn limit. ip: {get_ip(request)}. 
text: {text}") + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [CONVERSATION_LIMIT_MSG] + + [ + no_change_btn, + ] + * 6 + + [""] + ) + + text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off + for i in range(num_sides): + states[i].conv.append_message(states[i].conv.roles[0], text) + states[i].conv.append_message(states[i].conv.roles[1], None) + states[i].skip_next = False + + hint_msg = "" + for i in range(num_sides): + if "deluxe" in states[i].model_name: + hint_msg = SLOW_MODEL_MSG + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [""] + + [ + disable_btn, + ] + * 6 + + [hint_msg] + ) + + +def bot_response_multi( + state0, + state1, + temperature, + top_p, + max_new_tokens, + request: gr.Request, +): + logger.info(f"bot_response_multi (anony). ip: {get_ip(request)}") + + if state0 is None or state0.skip_next: + # This generate call is skipped due to invalid inputs + yield ( + state0, + state1, + state0.to_gradio_chatbot(), + state1.to_gradio_chatbot(), + ) + (no_change_btn,) * 6 + return + + states = [state0, state1] + gen = [] + for i in range(num_sides): + gen.append( + bot_response( + states[i], + temperature, + top_p, + max_new_tokens, + request, + apply_rate_limit=False, + use_recommended_config=True, + ) + ) + + model_tpy = [] + for i in range(num_sides): + token_per_yield = 1 + if states[i].model_name in [ + "gemini-pro", + "gemma-1.1-2b-it", + "gemma-1.1-7b-it", + "phi-3-mini-4k-instruct", + "phi-3-mini-128k-instruct", + "snowflake-arctic-instruct", + ]: + token_per_yield = 30 + elif states[i].model_name in [ + "qwen-max-0428", + "qwen1.5-110b-chat", + "llava-v1.6-34b", + ]: + token_per_yield = 7 + elif states[i].model_name in [ + "qwen2-72b-instruct", + ]: + token_per_yield = 4 + model_tpy.append(token_per_yield) + + chatbots = [None] * num_sides + iters = 0 + while True: + stop = True + iters += 1 + for i in range(num_sides): + try: + # yield fewer times if chunk size is larger + if model_tpy[i] == 1 or (iters % model_tpy[i] == 1 or iters < 3): + ret = next(gen[i]) + states[i], chatbots[i] = ret[0], ret[1] + stop = False + except StopIteration: + pass + yield states + chatbots + [disable_btn] * 6 + if stop: + break + + +def build_side_by_side_ui_anony(models): + notice_markdown = f""" +# ⚔️ LMSYS Chatbot Arena: Benchmarking LLMs in the Wild +[Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2403.04132) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | [Kaggle Competition](https://www.kaggle.com/competitions/lmsys-chatbot-arena) + +{SURVEY_LINK} + +## 📣 News +- Chatbot Arena now supports images in beta. Check it out [here](https://chat.lmsys.org/?vision). + +## 📜 Rules +- Ask any question to two anonymous models (e.g., ChatGPT, Gemini, Claude, Llama) and vote for the better one! +- You can chat for multiple turns until you identify a winner. +- Votes won't be counted if model identities are revealed during the conversation. + +## 🏆 Chatbot Arena [Leaderboard](https://leaderboard.lmsys.org) +- We've collected **1,000,000+** human votes to compute an LLM leaderboard for 100+ models. Find out who is the 🥇LLM Champion [here](https://leaderboard.lmsys.org)! + +## 👇 Chat now! 
+""" + + states = [gr.State() for _ in range(num_sides)] + model_selectors = [None] * num_sides + chatbots = [None] * num_sides + + gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Group(elem_id="share-region-anony"): + with gr.Accordion( + f"🔍 Expand to see the descriptions of {len(models)} models", open=False + ): + model_description_md = get_model_description_md(models) + gr.Markdown(model_description_md, elem_id="model_description_markdown") + with gr.Row(): + for i in range(num_sides): + label = "Model A" if i == 0 else "Model B" + with gr.Column(): + chatbots[i] = gr.Chatbot( + label=label, + elem_id="chatbot", + height=650, + show_copy_button=True, + ) + + with gr.Row(): + for i in range(num_sides): + with gr.Column(): + model_selectors[i] = gr.Markdown( + anony_names[i], elem_id="model_selector_md" + ) + with gr.Row(): + slow_warning = gr.Markdown("") + + with gr.Row(): + leftvote_btn = gr.Button( + value="👈 A is better", visible=False, interactive=False + ) + rightvote_btn = gr.Button( + value="👉 B is better", visible=False, interactive=False + ) + tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False) + bothbad_btn = gr.Button( + value="👎 Both are bad", visible=False, interactive=False + ) + + with gr.Row(): + textbox = gr.Textbox( + show_label=False, + placeholder="👉 Enter your prompt and press ENTER", + elem_id="input_box", + ) + send_btn = gr.Button(value="Send", variant="primary", scale=0) + + with gr.Row() as button_row: + clear_btn = gr.Button(value="🎲 New Round", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + share_btn = gr.Button(value="📷 Share") + + with gr.Accordion("Parameters", open=False, visible=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=2048, + value=1600, + step=64, + interactive=True, + label="Max output tokens", + ) + + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + + # Register listeners + btn_list = [ + leftvote_btn, + rightvote_btn, + tie_btn, + bothbad_btn, + regenerate_btn, + clear_btn, + ] + leftvote_btn.click( + leftvote_last_response, + states + model_selectors, + model_selectors + + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, send_btn], + ) + rightvote_btn.click( + rightvote_last_response, + states + model_selectors, + model_selectors + + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, send_btn], + ) + tie_btn.click( + tievote_last_response, + states + model_selectors, + model_selectors + + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, send_btn], + ) + bothbad_btn.click( + bothbad_vote_last_response, + states + model_selectors, + model_selectors + + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn, send_btn], + ) + regenerate_btn.click( + regenerate, states, states + chatbots + [textbox] + btn_list + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + clear_btn.click( + clear_history, + None, + states + + chatbots + + model_selectors + + [textbox] + + btn_list + + [slow_warning] + + [send_btn], + ) + + share_js = """ +function (a, b, c, d) { + const captureElement = document.querySelector('#share-region-anony'); + 
html2canvas(captureElement) + .then(canvas => { + canvas.style.display = 'none' + document.body.appendChild(canvas) + return canvas + }) + .then(canvas => { + const image = canvas.toDataURL('image/png') + const a = document.createElement('a') + a.setAttribute('download', 'chatbot-arena.png') + a.setAttribute('href', image) + a.click() + canvas.remove() + }); + return [a, b, c, d]; +} +""" + share_btn.click(share_click, states + model_selectors, [], js=share_js) + + textbox.submit( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list + [slow_warning], + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, + [], + btn_list, + ) + + send_btn.click( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list, + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + + return states + model_selectors diff --git a/fastchat/serve/gradio_block_arena_named.py b/fastchat/serve/gradio_block_arena_named.py new file mode 100644 index 0000000000000000000000000000000000000000..7ee19b0413aa81756672f97cbcca4c5c5b42ef2e --- /dev/null +++ b/fastchat/serve/gradio_block_arena_named.py @@ -0,0 +1,501 @@ +""" +Chatbot Arena (side-by-side) tab. +Users chat with two chosen models. +""" + +import json +import time + +import gradio as gr +import numpy as np + +from fastchat.constants import ( + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, + SURVEY_LINK, +) +from fastchat.model.model_adapter import get_conversation_template +from fastchat.serve.gradio_web_server import ( + State, + bot_response, + get_conv_log_filename, + no_change_btn, + enable_btn, + disable_btn, + invisible_btn, + acknowledgment_md, + get_ip, + get_model_description_md, +) +from fastchat.serve.remote_logger import get_remote_logger +from fastchat.utils import ( + build_logger, + moderation_filter, +) + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + +num_sides = 2 +enable_moderation = False + + +def set_global_vars_named(enable_moderation_): + global enable_moderation + enable_moderation = enable_moderation_ + + +def load_demo_side_by_side_named(models, url_params): + states = (None,) * num_sides + + model_left = models[0] if len(models) > 0 else "" + if len(models) > 1: + weights = ([8] * 4 + [4] * 8 + [1] * 64)[: len(models) - 1] + weights = weights / np.sum(weights) + model_right = np.random.choice(models[1:], p=weights) + else: + model_right = model_left + + selector_updates = ( + gr.Dropdown(choices=models, value=model_left, visible=True), + gr.Dropdown(choices=models, value=model_right, visible=True), + ) + + return states + selector_updates + + +def vote_last_response(states, vote_type, model_selectors, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "models": [x for x in model_selectors], + "states": [x.dict() for x in states], + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + get_remote_logger().log(data) + + +def leftvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"leftvote (named). 
ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "leftvote", [model_selector0, model_selector1], request + ) + return ("",) + (disable_btn,) * 4 + + +def rightvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"rightvote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "rightvote", [model_selector0, model_selector1], request + ) + return ("",) + (disable_btn,) * 4 + + +def tievote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"tievote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "tievote", [model_selector0, model_selector1], request + ) + return ("",) + (disable_btn,) * 4 + + +def bothbad_vote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"bothbad_vote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request + ) + return ("",) + (disable_btn,) * 4 + + +def regenerate(state0, state1, request: gr.Request): + logger.info(f"regenerate (named). ip: {get_ip(request)}") + states = [state0, state1] + if state0.regen_support and state1.regen_support: + for i in range(num_sides): + states[i].conv.update_last_message(None) + return ( + states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6 + ) + states[0].skip_next = True + states[1].skip_next = True + return states + [x.to_gradio_chatbot() for x in states] + [""] + [no_change_btn] * 6 + + +def clear_history(request: gr.Request): + logger.info(f"clear_history (named). ip: {get_ip(request)}") + return ( + [None] * num_sides + + [None] * num_sides + + [""] + + [invisible_btn] * 4 + + [disable_btn] * 2 + ) + + +def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request): + logger.info(f"share (named). ip: {get_ip(request)}") + if state0 is not None and state1 is not None: + vote_last_response( + [state0, state1], "share", [model_selector0, model_selector1], request + ) + + +def add_text( + state0, state1, model_selector0, model_selector1, text, request: gr.Request +): + ip = get_ip(request) + logger.info(f"add_text (named). ip: {ip}. len: {len(text)}") + states = [state0, state1] + model_selectors = [model_selector0, model_selector1] + + # Init states if necessary + for i in range(num_sides): + if states[i] is None: + states[i] = State(model_selectors[i]) + + if len(text) <= 0: + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + ["", None] + + [ + no_change_btn, + ] + * 6 + ) + + model_list = [states[i].model_name for i in range(num_sides)] + all_conv_text_left = states[0].conv.get_prompt() + all_conv_text_right = states[1].conv.get_prompt() + all_conv_text = ( + all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text + ) + flagged = moderation_filter(all_conv_text, model_list) + if flagged: + logger.info(f"violate moderation (named). ip: {ip}. text: {text}") + # overwrite the original text + text = MODERATION_MSG + + conv = states[0].conv + if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + logger.info(f"conversation turn limit. ip: {ip}. 
text: {text}") + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [CONVERSATION_LIMIT_MSG] + + [ + no_change_btn, + ] + * 6 + ) + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + for i in range(num_sides): + states[i].conv.append_message(states[i].conv.roles[0], text) + states[i].conv.append_message(states[i].conv.roles[1], None) + states[i].skip_next = False + + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [""] + + [ + disable_btn, + ] + * 6 + ) + + +def bot_response_multi( + state0, + state1, + temperature, + top_p, + max_new_tokens, + request: gr.Request, +): + logger.info(f"bot_response_multi (named). ip: {get_ip(request)}") + + if state0.skip_next: + # This generate call is skipped due to invalid inputs + yield ( + state0, + state1, + state0.to_gradio_chatbot(), + state1.to_gradio_chatbot(), + ) + (no_change_btn,) * 6 + return + + states = [state0, state1] + gen = [] + for i in range(num_sides): + gen.append( + bot_response( + states[i], + temperature, + top_p, + max_new_tokens, + request, + ) + ) + + model_tpy = [] + for i in range(num_sides): + token_per_yield = 1 + if states[i].model_name in [ + "gemini-pro", + "gemma-1.1-2b-it", + "gemma-1.1-7b-it", + "phi-3-mini-4k-instruct", + "phi-3-mini-128k-instruct", + "snowflake-arctic-instruct", + ]: + token_per_yield = 30 + elif states[i].model_name in [ + "qwen-max-0428", + "qwen1.5-110b-chat", + ]: + token_per_yield = 7 + elif states[i].model_name in [ + "qwen2-72b-instruct", + ]: + token_per_yield = 4 + model_tpy.append(token_per_yield) + + chatbots = [None] * num_sides + iters = 0 + while True: + stop = True + iters += 1 + for i in range(num_sides): + try: + # yield fewer times if chunk size is larger + if model_tpy[i] == 1 or (iters % model_tpy[i] == 1 or iters < 3): + ret = next(gen[i]) + states[i], chatbots[i] = ret[0], ret[1] + stop = False + except StopIteration: + pass + yield states + chatbots + [disable_btn] * 6 + if stop: + break + + +def flash_buttons(): + btn_updates = [ + [disable_btn] * 4 + [enable_btn] * 2, + [enable_btn] * 6, + ] + for i in range(4): + yield btn_updates[i % 2] + time.sleep(0.3) + + +def build_side_by_side_ui_named(models): + notice_markdown = f""" +# ⚔️ LMSYS Chatbot Arena: Benchmarking LLMs in the Wild +[Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2403.04132) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | [Kaggle Competition](https://www.kaggle.com/competitions/lmsys-chatbot-arena) + +{SURVEY_LINK} + +## 📜 Rules +- Ask any question to two chosen models (e.g., ChatGPT, Gemini, Claude, Llama) and vote for the better one! +- You can chat for multiple turns until you identify a winner. 
+ +## 👇 Choose two models to compare +""" + + states = [gr.State() for _ in range(num_sides)] + model_selectors = [None] * num_sides + chatbots = [None] * num_sides + + notice = gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Group(elem_id="share-region-named"): + with gr.Row(): + for i in range(num_sides): + with gr.Column(): + model_selectors[i] = gr.Dropdown( + choices=models, + value=models[i] if len(models) > i else "", + interactive=True, + show_label=False, + container=False, + ) + with gr.Row(): + with gr.Accordion( + f"🔍 Expand to see the descriptions of {len(models)} models", open=False + ): + model_description_md = get_model_description_md(models) + gr.Markdown(model_description_md, elem_id="model_description_markdown") + + with gr.Row(): + for i in range(num_sides): + label = "Model A" if i == 0 else "Model B" + with gr.Column(): + chatbots[i] = gr.Chatbot( + label=label, + elem_id=f"chatbot", + height=650, + show_copy_button=True, + ) + + with gr.Row(): + leftvote_btn = gr.Button( + value="👈 A is better", visible=False, interactive=False + ) + rightvote_btn = gr.Button( + value="👉 B is better", visible=False, interactive=False + ) + tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False) + bothbad_btn = gr.Button( + value="👎 Both are bad", visible=False, interactive=False + ) + + with gr.Row(): + textbox = gr.Textbox( + show_label=False, + placeholder="👉 Enter your prompt and press ENTER", + elem_id="input_box", + ) + send_btn = gr.Button(value="Send", variant="primary", scale=0) + + with gr.Row() as button_row: + clear_btn = gr.Button(value="🗑️ Clear history", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + share_btn = gr.Button(value="📷 Share") + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=2048, + value=1024, + step=64, + interactive=True, + label="Max output tokens", + ) + + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + + # Register listeners + btn_list = [ + leftvote_btn, + rightvote_btn, + tie_btn, + bothbad_btn, + regenerate_btn, + clear_btn, + ] + leftvote_btn.click( + leftvote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + rightvote_btn.click( + rightvote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + tie_btn.click( + tievote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + bothbad_btn.click( + bothbad_vote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + regenerate_btn.click( + regenerate, states, states + chatbots + [textbox] + btn_list + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + clear_btn.click(clear_history, None, states + chatbots + [textbox] + btn_list) + + share_js = """ +function (a, b, c, d) { + const captureElement = document.querySelector('#share-region-named'); + html2canvas(captureElement) + .then(canvas => { + canvas.style.display = 'none' + document.body.appendChild(canvas) + 
return canvas + }) + .then(canvas => { + const image = canvas.toDataURL('image/png') + const a = document.createElement('a') + a.setAttribute('download', 'chatbot-arena.png') + a.setAttribute('href', image) + a.click() + canvas.remove() + }); + return [a, b, c, d]; +} +""" + share_btn.click(share_click, states + model_selectors, [], js=share_js) + + for i in range(num_sides): + model_selectors[i].change( + clear_history, None, states + chatbots + [textbox] + btn_list + ) + + textbox.submit( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list, + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + send_btn.click( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list, + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + + return states + model_selectors diff --git a/fastchat/serve/gradio_block_arena_vision.py b/fastchat/serve/gradio_block_arena_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..25ff78c087bceca113fe5a6372b963fb50b774de --- /dev/null +++ b/fastchat/serve/gradio_block_arena_vision.py @@ -0,0 +1,432 @@ +""" +The gradio demo server for chatting with a large multimodal model. + +Usage: +python3 -m fastchat.serve.controller +python3 -m fastchat.serve.sglang_worker --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf +python3 -m fastchat.serve.gradio_web_server_multi --share --vision-arena +""" + +import json +import os +import time + +import gradio as gr +from gradio.data_classes import FileData +import numpy as np + +from fastchat.constants import ( + TEXT_MODERATION_MSG, + IMAGE_MODERATION_MSG, + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, + SURVEY_LINK, +) +from fastchat.model.model_adapter import ( + get_conversation_template, +) +from fastchat.serve.gradio_web_server import ( + get_model_description_md, + acknowledgment_md, + bot_response, + get_ip, + disable_btn, + State, + get_conv_log_filename, + get_remote_logger, +) +from fastchat.serve.vision.image import ImageFormat, Image +from fastchat.utils import ( + build_logger, + moderation_filter, + image_moderation_filter, +) + +logger = build_logger("gradio_web_server", "gradio_web_server.log") + +no_change_btn = gr.Button() +enable_btn = gr.Button(interactive=True, visible=True) +disable_btn = gr.Button(interactive=False) +invisible_btn = gr.Button(interactive=False, visible=False) +visible_image_column = gr.Image(visible=True) +invisible_image_column = gr.Image(visible=False) +enable_multimodal = gr.MultimodalTextbox( + interactive=True, visible=True, placeholder="Enter your prompt or add image here" +) +invisible_text = gr.Textbox(visible=False, value="", interactive=False) +visible_text = gr.Textbox( + visible=True, + value="", + interactive=True, + placeholder="👉 Enter your prompt and press ENTER", +) +disable_multimodal = gr.MultimodalTextbox(visible=False, value=None, interactive=False) + + +def get_vqa_sample(): + random_sample = np.random.choice(vqa_samples) + question, path = random_sample["question"], random_sample["path"] + res = {"text": "", "files": [path]} + return (res, path) + + +def set_visible_image(textbox): + images = textbox["files"] + if len(images) == 0: + return invisible_image_column + elif len(images) > 
1: + gr.Warning( + "We only support single image conversations. Please start a new round if you would like to chat using this image." + ) + + return visible_image_column + + +def set_invisible_image(): + return invisible_image_column + + +def add_image(textbox): + images = textbox["files"] + if len(images) == 0: + return None + + return images[0] + + +def vote_last_response(state, vote_type, model_selector, request: gr.Request): + filename = get_conv_log_filename(state.is_vision, state.has_csam_image) + with open(filename, "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "model": model_selector, + "state": state.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + get_remote_logger().log(data) + + +def upvote_last_response(state, model_selector, request: gr.Request): + ip = get_ip(request) + logger.info(f"upvote. ip: {ip}") + vote_last_response(state, "upvote", model_selector, request) + return (None,) + (disable_btn,) * 3 + + +def downvote_last_response(state, model_selector, request: gr.Request): + ip = get_ip(request) + logger.info(f"downvote. ip: {ip}") + vote_last_response(state, "downvote", model_selector, request) + return (None,) + (disable_btn,) * 3 + + +def flag_last_response(state, model_selector, request: gr.Request): + ip = get_ip(request) + logger.info(f"flag. ip: {ip}") + vote_last_response(state, "flag", model_selector, request) + return (None,) + (disable_btn,) * 3 + + +def regenerate(state, request: gr.Request): + ip = get_ip(request) + logger.info(f"regenerate. ip: {ip}") + if not state.regen_support: + state.skip_next = True + return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 + state.conv.update_last_message(None) + return (state, state.to_gradio_chatbot(), None) + (disable_btn,) * 5 + + +def clear_history(request: gr.Request): + ip = get_ip(request) + logger.info(f"clear_history. ip: {ip}") + state = None + return (state, [], None) + (disable_btn,) * 5 + + +def clear_history_example(request: gr.Request): + ip = get_ip(request) + logger.info(f"clear_history_example. ip: {ip}") + state = None + return (state, [], enable_multimodal) + (disable_btn,) * 5 + + +# TODO(Chris): At some point, we would like this to be a live-reporting feature. +def report_csam_image(state, image): + pass + + +def _prepare_text_with_image(state, text, images, csam_flag): + if len(images) > 0: + if len(state.conv.get_images()) > 0: + # reset convo with new image + state.conv = get_conversation_template(state.model_name) + + text = text, [images[0]] + + return text + + +# NOTE(chris): take multiple images later on +def convert_images_to_conversation_format(images): + import base64 + + MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB = 5 / 1.5 + conv_images = [] + if len(images) > 0: + conv_image = Image(url=images[0]) + conv_image.to_conversation_format(MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB) + conv_images.append(conv_image) + + return conv_images + + +def moderate_input(state, text, all_conv_text, model_list, images, ip): + text_flagged = moderation_filter(all_conv_text, model_list) + # flagged = moderation_filter(text, [state.model_name]) + nsfw_flagged, csam_flagged = False, False + if len(images) > 0: + nsfw_flagged, csam_flagged = image_moderation_filter(images[0]) + + image_flagged = nsfw_flagged or csam_flagged + if text_flagged or image_flagged: + logger.info(f"violate moderation. ip: {ip}. 
text: {all_conv_text}") + if text_flagged and not image_flagged: + # overwrite the original text + text = TEXT_MODERATION_MSG + elif not text_flagged and image_flagged: + text = IMAGE_MODERATION_MSG + elif text_flagged and image_flagged: + text = MODERATION_MSG + + if csam_flagged: + state.has_csam_image = True + report_csam_image(state, images[0]) + + return text, image_flagged, csam_flagged + + +def add_text(state, model_selector, chat_input, request: gr.Request): + text, images = chat_input["text"], chat_input["files"] + ip = get_ip(request) + logger.info(f"add_text. ip: {ip}. len: {len(text)}") + + if state is None: + state = State(model_selector, is_vision=True) + + if len(text) <= 0: + state.skip_next = True + return (state, state.to_gradio_chatbot(), None) + (no_change_btn,) * 5 + + all_conv_text = state.conv.get_prompt() + all_conv_text = all_conv_text[-2000:] + "\nuser: " + text + + images = convert_images_to_conversation_format(images) + + text, image_flagged, csam_flag = moderate_input( + state, text, all_conv_text, [state.model_name], images, ip + ) + + if image_flagged: + logger.info(f"image flagged. ip: {ip}. text: {text}") + state.skip_next = True + return (state, state.to_gradio_chatbot(), {"text": IMAGE_MODERATION_MSG}) + ( + no_change_btn, + ) * 5 + + if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + logger.info(f"conversation turn limit. ip: {ip}. text: {text}") + state.skip_next = True + return (state, state.to_gradio_chatbot(), {"text": CONVERSATION_LIMIT_MSG}) + ( + no_change_btn, + ) * 5 + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + text = _prepare_text_with_image(state, text, images, csam_flag=csam_flag) + state.conv.append_message(state.conv.roles[0], text) + state.conv.append_message(state.conv.roles[1], None) + return (state, state.to_gradio_chatbot(), None) + (disable_btn,) * 5 + + +def build_single_vision_language_model_ui( + models, add_promotion_links=False, random_questions=None +): + promotion = ( + f""" +- [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) + +{SURVEY_LINK} + +**❗️ For research purposes, we log user prompts and images, and may release this data to the public in the future. Please do not upload any confidential or personal information.** + +Note: You can only chat with one image per conversation. You can upload images less than 15MB. 
Click the "Random Example" button to chat with a random image.""" + if add_promotion_links + else "" + ) + + notice_markdown = f""" +# 🏔️ Chat with Large Vision-Language Models +{promotion} +""" + + state = gr.State() + gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Group(): + with gr.Row(elem_id="model_selector_row"): + model_selector = gr.Dropdown( + choices=models, + value=models[0] if len(models) > 0 else "", + interactive=True, + show_label=False, + container=False, + ) + + with gr.Accordion( + f"🔍 Expand to see the descriptions of {len(models)} models", open=False + ): + model_description_md = get_model_description_md(models) + gr.Markdown(model_description_md, elem_id="model_description_markdown") + + with gr.Row(): + textbox = gr.MultimodalTextbox( + file_types=["image"], + show_label=False, + placeholder="Enter your prompt or add image here", + container=True, + render=False, + elem_id="input_box", + ) + + with gr.Column(scale=2, visible=False) as image_column: + imagebox = gr.Image( + type="pil", + show_label=False, + interactive=False, + ) + with gr.Column(scale=8): + chatbot = gr.Chatbot( + elem_id="chatbot", label="Scroll down and start chatting", height=650 + ) + + with gr.Row(): + textbox.render() + # with gr.Column(scale=1, min_width=50): + # send_btn = gr.Button(value="Send", variant="primary") + + with gr.Row(elem_id="buttons"): + if random_questions: + global vqa_samples + with open(random_questions, "r") as f: + vqa_samples = json.load(f) + random_btn = gr.Button(value="🎲 Random Example", interactive=True) + upvote_btn = gr.Button(value="👍 Upvote", interactive=False) + downvote_btn = gr.Button(value="👎 Downvote", interactive=False) + flag_btn = gr.Button(value="⚠️ Flag", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + clear_btn = gr.Button(value="🗑️ Clear", interactive=False) + + cur_dir = os.path.dirname(os.path.abspath(__file__)) + + examples = gr.Examples( + examples=[ + { + "text": "How can I prepare a delicious meal using these ingredients?", + "files": [f"{cur_dir}/example_images/fridge.jpg"], + }, + { + "text": "What might the woman on the right be thinking about?", + "files": [f"{cur_dir}/example_images/distracted.jpg"], + }, + ], + inputs=[textbox], + ) + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.2, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=0, + maximum=2048, + value=1024, + step=64, + interactive=True, + label="Max output tokens", + ) + + if add_promotion_links: + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + + # Register listeners + btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn] + upvote_btn.click( + upvote_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + downvote_btn.click( + downvote_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + flag_btn.click( + flag_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( + bot_response, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + clear_btn.click(clear_history, None, [state, chatbot, textbox] + 
btn_list) + + model_selector.change( + clear_history, None, [state, chatbot, textbox] + btn_list + ).then(set_visible_image, [textbox], [image_column]) + examples.dataset.click( + clear_history_example, None, [state, chatbot, textbox] + btn_list + ) + + textbox.input(add_image, [textbox], [imagebox]).then( + set_visible_image, [textbox], [image_column] + ).then(clear_history_example, None, [state, chatbot, textbox] + btn_list) + + textbox.submit( + add_text, + [state, model_selector, textbox], + [state, chatbot, textbox] + btn_list, + ).then(set_invisible_image, [], [image_column]).then( + bot_response, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + + if random_questions: + random_btn.click( + get_vqa_sample, # First, get the VQA sample + [], # Pass the path to the VQA samples + [textbox, imagebox], # Outputs are textbox and imagebox + ).then(set_visible_image, [textbox], [image_column]).then( + clear_history_example, None, [state, chatbot, textbox] + btn_list + ) + + return [state, model_selector] diff --git a/fastchat/serve/gradio_block_arena_vision_anony.py b/fastchat/serve/gradio_block_arena_vision_anony.py new file mode 100644 index 0000000000000000000000000000000000000000..9e54930ef06a1bc66eb2cdfb2c42455e3d66721c --- /dev/null +++ b/fastchat/serve/gradio_block_arena_vision_anony.py @@ -0,0 +1,643 @@ +""" +Chatbot Arena (battle) tab. +Users chat with two anonymous models. +""" + +import json +import time + +import gradio as gr +import numpy as np + +from fastchat.constants import ( + TEXT_MODERATION_MSG, + IMAGE_MODERATION_MSG, + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + SLOW_MODEL_MSG, + BLIND_MODE_INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, + SURVEY_LINK, +) +from fastchat.model.model_adapter import get_conversation_template +from fastchat.serve.gradio_block_arena_named import flash_buttons +from fastchat.serve.gradio_web_server import ( + State, + bot_response, + get_conv_log_filename, + no_change_btn, + enable_btn, + disable_btn, + invisible_btn, + acknowledgment_md, + get_ip, + get_model_description_md, + disable_text, + enable_text, +) +from fastchat.serve.gradio_block_arena_anony import ( + flash_buttons, + vote_last_response, + leftvote_last_response, + rightvote_last_response, + tievote_last_response, + bothbad_vote_last_response, + regenerate, + clear_history, + share_click, + add_text, + bot_response_multi, + set_global_vars_anony, + load_demo_side_by_side_anony, + get_sample_weight, + get_battle_pair, + SAMPLING_WEIGHTS, + BATTLE_TARGETS, + SAMPLING_BOOST_MODELS, + OUTAGE_MODELS, +) +from fastchat.serve.gradio_block_arena_vision import ( + set_invisible_image, + set_visible_image, + add_image, + moderate_input, + enable_multimodal, + _prepare_text_with_image, + convert_images_to_conversation_format, + invisible_text, + visible_text, + disable_multimodal, +) +from fastchat.serve.remote_logger import get_remote_logger +from fastchat.utils import ( + build_logger, + moderation_filter, + image_moderation_filter, +) + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + +num_sides = 2 +enable_moderation = False +anony_names = ["", ""] +text_models = [] +vl_models = [] + +# TODO(chris): fix sampling weights +VISION_SAMPLING_WEIGHTS = { + "gpt-4o-2024-05-13": 4, + "gpt-4-turbo-2024-04-09": 4, + "claude-3-haiku-20240307": 4, + "claude-3-sonnet-20240229": 4, + "claude-3-5-sonnet-20240620": 4, + "claude-3-opus-20240229": 4, + "gemini-1.5-flash-api-0514": 4, + "gemini-1.5-pro-api-0514": 4, + 
"llava-v1.6-34b": 4, + "reka-core-20240501": 4, + "reka-flash-preview-20240611": 4, +} + +# TODO(chris): Find battle targets that make sense +VISION_BATTLE_TARGETS = {} + +# TODO(chris): Fill out models that require sampling boost +VISION_SAMPLING_BOOST_MODELS = [] + +# outage models won't be sampled. +VISION_OUTAGE_MODELS = [] + + +def get_vqa_sample(): + random_sample = np.random.choice(vqa_samples) + question, path = random_sample["question"], random_sample["path"] + res = {"text": "", "files": [path]} + return (res, path) + + +def load_demo_side_by_side_vision_anony(all_text_models, all_vl_models, url_params): + global text_models, vl_models + text_models = all_text_models + vl_models = all_vl_models + + states = (None,) * num_sides + selector_updates = ( + gr.Markdown(visible=True), + gr.Markdown(visible=True), + ) + + return states + selector_updates + + +def clear_history_example(request: gr.Request): + logger.info(f"clear_history_example (anony). ip: {get_ip(request)}") + return ( + [None] * num_sides + + [None] * num_sides + + anony_names + + [enable_multimodal, invisible_text] + + [invisible_btn] * 4 + + [disable_btn] * 2 + + [enable_btn] + ) + + +def vote_last_response(states, vote_type, model_selectors, request: gr.Request): + filename = get_conv_log_filename(states[0].is_vision, states[0].has_csam_image) + + with open(filename, "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "models": [x for x in model_selectors], + "states": [x.dict() for x in states], + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + get_remote_logger().log(data) + + gr.Info( + "🎉 Thanks for voting! Your vote shapes the leaderboard, please vote RESPONSIBLY." + ) + if ":" not in model_selectors[0]: + for i in range(5): + names = ( + "### Model A: " + states[0].model_name, + "### Model B: " + states[1].model_name, + ) + yield names + (disable_text,) + (disable_btn,) * 4 + time.sleep(0.1) + else: + names = ( + "### Model A: " + states[0].model_name, + "### Model B: " + states[1].model_name, + ) + yield names + (disable_text,) + (disable_btn,) * 4 + + +def leftvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"leftvote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "leftvote", [model_selector0, model_selector1], request + ): + yield x + + +def rightvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"rightvote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "rightvote", [model_selector0, model_selector1], request + ): + yield x + + +def tievote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"tievote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "tievote", [model_selector0, model_selector1], request + ): + yield x + + +def bothbad_vote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request + ): + yield x + + +def regenerate(state0, state1, request: gr.Request): + logger.info(f"regenerate (anony). 
ip: {get_ip(request)}") + states = [state0, state1] + if state0.regen_support and state1.regen_support: + for i in range(num_sides): + states[i].conv.update_last_message(None) + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [None] + + [disable_btn] * 6 + ) + states[0].skip_next = True + states[1].skip_next = True + return ( + states + [x.to_gradio_chatbot() for x in states] + [None] + [no_change_btn] * 6 + ) + + +def clear_history(request: gr.Request): + logger.info(f"clear_history (anony). ip: {get_ip(request)}") + return ( + [None] * num_sides + + [None] * num_sides + + anony_names + + [enable_multimodal, invisible_text] + + [invisible_btn] * 4 + + [disable_btn] * 2 + + [enable_btn] + + [""] + ) + + +def add_text( + state0, state1, model_selector0, model_selector1, chat_input, request: gr.Request +): + if isinstance(chat_input, dict): + text, images = chat_input["text"], chat_input["files"] + else: + text = chat_input + images = [] + + ip = get_ip(request) + logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}") + states = [state0, state1] + model_selectors = [model_selector0, model_selector1] + + # Init states if necessary + if states[0] is None: + assert states[1] is None + + if len(images) > 0: + model_left, model_right = get_battle_pair( + vl_models, + VISION_BATTLE_TARGETS, + VISION_OUTAGE_MODELS, + VISION_SAMPLING_WEIGHTS, + VISION_SAMPLING_BOOST_MODELS, + ) + states = [ + State(model_left, is_vision=True), + State(model_right, is_vision=True), + ] + else: + model_left, model_right = get_battle_pair( + text_models, + BATTLE_TARGETS, + OUTAGE_MODELS, + SAMPLING_WEIGHTS, + SAMPLING_BOOST_MODELS, + ) + + states = [ + State(model_left, is_vision=False), + State(model_right, is_vision=False), + ] + + if len(text) <= 0: + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [None, ""] + + [ + no_change_btn, + ] + * 7 + + [""] + ) + + model_list = [states[i].model_name for i in range(num_sides)] + + images = convert_images_to_conversation_format(images) + + text, image_flagged, csam_flag = moderate_input( + state0, text, text, model_list, images, ip + ) + + conv = states[0].conv + if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + logger.info(f"conversation turn limit. ip: {get_ip(request)}. text: {text}") + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [{"text": CONVERSATION_LIMIT_MSG}, ""] + + [ + no_change_btn, + ] + * 7 + + [""] + ) + + if image_flagged: + logger.info(f"image flagged. ip: {ip}. text: {text}") + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [ + { + "text": IMAGE_MODERATION_MSG + + " PLEASE CLICK 🎲 NEW ROUND TO START A NEW CONVERSATION." 
+ }, + "", + ] + + [no_change_btn] * 7 + + [""] + ) + + text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off + for i in range(num_sides): + post_processed_text = _prepare_text_with_image( + states[i], text, images, csam_flag=csam_flag + ) + states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) + states[i].conv.append_message(states[i].conv.roles[1], None) + states[i].skip_next = False + + hint_msg = "" + for i in range(num_sides): + if "deluxe" in states[i].model_name: + hint_msg = SLOW_MODEL_MSG + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [disable_multimodal, visible_text] + + [ + disable_btn, + ] + * 7 + + [hint_msg] + ) + + +def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=None): + notice_markdown = f""" +# ⚔️ LMSYS Chatbot Arena (Multimodal): Benchmarking LLMs and VLMs in the Wild +[Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2403.04132) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | [Kaggle Competition](https://www.kaggle.com/competitions/lmsys-chatbot-arena) + +{SURVEY_LINK} + +## 📜 Rules +- Ask any question to two anonymous models (e.g., ChatGPT, Gemini, Claude, Llama) and vote for the better one! +- You can continue chatting until you identify a winner. +- Vote won't be counted if model identity is revealed during conversation. +- **NEW** Image Support: Upload an image on your first turn to unlock the multimodal arena! Images should be less than 15MB. + +## 🏆 Chatbot Arena [Leaderboard](https://leaderboard.lmsys.org) +- We've collected **1,000,000+** human votes to compute an LLM Elo leaderboard for 100+ models. Find out who is the 🥇LLM Champion [here](https://leaderboard.lmsys.org)! + +## 👇 Chat now! 
+""" + + states = [gr.State() for _ in range(num_sides)] + model_selectors = [None] * num_sides + chatbots = [None] * num_sides + + gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Row(): + with gr.Column(scale=2, visible=False) as image_column: + imagebox = gr.Image( + type="pil", + show_label=False, + interactive=False, + ) + + with gr.Column(scale=5): + with gr.Group(elem_id="share-region-anony"): + with gr.Accordion( + f"🔍 Expand to see the descriptions of {len(text_models) + len(vl_models)} models", + open=False, + ): + model_description_md = get_model_description_md( + text_models + vl_models + ) + gr.Markdown( + model_description_md, elem_id="model_description_markdown" + ) + + with gr.Row(): + for i in range(num_sides): + label = "Model A" if i == 0 else "Model B" + with gr.Column(): + chatbots[i] = gr.Chatbot( + label=label, + elem_id="chatbot", + height=650, + show_copy_button=True, + ) + + with gr.Row(): + for i in range(num_sides): + with gr.Column(): + model_selectors[i] = gr.Markdown( + anony_names[i], elem_id="model_selector_md" + ) + with gr.Row(): + slow_warning = gr.Markdown("", elem_id="notice_markdown") + + with gr.Row(): + leftvote_btn = gr.Button( + value="👈 A is better", visible=False, interactive=False + ) + rightvote_btn = gr.Button( + value="👉 B is better", visible=False, interactive=False + ) + tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False) + bothbad_btn = gr.Button( + value="👎 Both are bad", visible=False, interactive=False + ) + + with gr.Row(): + textbox = gr.Textbox( + show_label=False, + placeholder="👉 Enter your prompt and press ENTER", + elem_id="input_box", + visible=False, + ) + + multimodal_textbox = gr.MultimodalTextbox( + file_types=["image"], + show_label=False, + container=True, + placeholder="Enter your prompt or add image here", + elem_id="input_box", + ) + # send_btn = gr.Button(value="Send", variant="primary", scale=0) + + with gr.Row() as button_row: + if random_questions: + global vqa_samples + with open(random_questions, "r") as f: + vqa_samples = json.load(f) + random_btn = gr.Button(value="🔮 Random Image", interactive=True) + clear_btn = gr.Button(value="🎲 New Round", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + share_btn = gr.Button(value="📷 Share") + + with gr.Accordion("Parameters", open=False, visible=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=2048, + value=1800, + step=64, + interactive=True, + label="Max output tokens", + ) + + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + + # Register listeners + btn_list = [ + leftvote_btn, + rightvote_btn, + tie_btn, + bothbad_btn, + regenerate_btn, + clear_btn, + ] + leftvote_btn.click( + leftvote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + rightvote_btn.click( + rightvote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + tie_btn.click( + tievote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + bothbad_btn.click( + bothbad_vote_last_response, + states + model_selectors, + 
model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + regenerate_btn.click( + regenerate, states, states + chatbots + [textbox] + btn_list + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + clear_btn.click( + clear_history, + None, + states + + chatbots + + model_selectors + + [multimodal_textbox, textbox] + + btn_list + + ([random_btn] if random_questions else []) + + [slow_warning], + ) + + share_js = """ +function (a, b, c, d) { + const captureElement = document.querySelector('#share-region-anony'); + html2canvas(captureElement) + .then(canvas => { + canvas.style.display = 'none' + document.body.appendChild(canvas) + return canvas + }) + .then(canvas => { + const image = canvas.toDataURL('image/png') + const a = document.createElement('a') + a.setAttribute('download', 'chatbot-arena.png') + a.setAttribute('href', image) + a.click() + canvas.remove() + }); + return [a, b, c, d]; +} +""" + share_btn.click(share_click, states + model_selectors, [], js=share_js) + + multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then( + set_visible_image, [multimodal_textbox], [image_column] + ).then( + clear_history_example, + None, + states + chatbots + model_selectors + [multimodal_textbox, textbox] + btn_list, + ) + + multimodal_textbox.submit( + add_text, + states + model_selectors + [multimodal_textbox], + states + + chatbots + + [multimodal_textbox, textbox] + + btn_list + + ([random_btn] if random_questions else []) + + [slow_warning], + ).then(set_invisible_image, [], [image_column]).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, + [], + btn_list, + ) + + textbox.submit( + add_text, + states + model_selectors + [textbox], + states + + chatbots + + [multimodal_textbox, textbox] + + btn_list + + ([random_btn] if random_questions else []) + + [slow_warning], + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, + [], + btn_list, + ) + + if random_questions: + random_btn.click( + get_vqa_sample, # First, get the VQA sample + [], # Pass the path to the VQA samples + [multimodal_textbox, imagebox], # Outputs are textbox and imagebox + ).then(set_visible_image, [multimodal_textbox], [image_column]).then( + clear_history_example, + None, + states + + chatbots + + model_selectors + + [multimodal_textbox, textbox] + + btn_list + + ([random_btn] if random_questions else []), + ) + + return states + model_selectors diff --git a/fastchat/serve/gradio_block_arena_vision_named.py b/fastchat/serve/gradio_block_arena_vision_named.py new file mode 100644 index 0000000000000000000000000000000000000000..ecca169ca2da7d29d78ac71e98f37b4a6ef75ef8 --- /dev/null +++ b/fastchat/serve/gradio_block_arena_vision_named.py @@ -0,0 +1,466 @@ +""" +Multimodal Chatbot Arena (side-by-side) tab. +Users chat with two chosen models. 
+""" + +import json +import os +import time + +import gradio as gr +import numpy as np + +from fastchat.constants import ( + TEXT_MODERATION_MSG, + IMAGE_MODERATION_MSG, + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + SLOW_MODEL_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, + SURVEY_LINK, +) +from fastchat.model.model_adapter import get_conversation_template +from fastchat.serve.gradio_block_arena_named import ( + flash_buttons, + share_click, + bot_response_multi, +) +from fastchat.serve.gradio_block_arena_vision import ( + get_vqa_sample, + set_invisible_image, + set_visible_image, + add_image, + moderate_input, + _prepare_text_with_image, + convert_images_to_conversation_format, + enable_multimodal, +) +from fastchat.serve.gradio_web_server import ( + State, + bot_response, + get_conv_log_filename, + no_change_btn, + enable_btn, + disable_btn, + invisible_btn, + acknowledgment_md, + get_ip, + get_model_description_md, + enable_text, +) +from fastchat.serve.remote_logger import get_remote_logger +from fastchat.utils import ( + build_logger, + moderation_filter, + image_moderation_filter, +) + + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + +num_sides = 2 +enable_moderation = False + + +def clear_history_example(request: gr.Request): + logger.info(f"clear_history_example (named). ip: {get_ip(request)}") + return ( + [None] * num_sides + + [None] * num_sides + + [enable_multimodal] + + [invisible_btn] * 4 + + [disable_btn] * 2 + ) + + +def vote_last_response(states, vote_type, model_selectors, request: gr.Request): + filename = get_conv_log_filename(states[0].is_vision, states[0].has_csam_image) + with open(filename, "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "models": [x for x in model_selectors], + "states": [x.dict() for x in states], + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + get_remote_logger().log(data) + + +def leftvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"leftvote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "leftvote", [model_selector0, model_selector1], request + ) + return (None,) + (disable_btn,) * 4 + + +def rightvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"rightvote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "rightvote", [model_selector0, model_selector1], request + ) + return (None,) + (disable_btn,) * 4 + + +def tievote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"tievote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "tievote", [model_selector0, model_selector1], request + ) + return (None,) + (disable_btn,) * 4 + + +def bothbad_vote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"bothbad_vote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request + ) + return (None,) + (disable_btn,) * 4 + + +def regenerate(state0, state1, request: gr.Request): + logger.info(f"regenerate (named). 
ip: {get_ip(request)}") + states = [state0, state1] + if state0.regen_support and state1.regen_support: + for i in range(num_sides): + states[i].conv.update_last_message(None) + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [None] + + [disable_btn] * 6 + ) + states[0].skip_next = True + states[1].skip_next = True + return ( + states + [x.to_gradio_chatbot() for x in states] + [None] + [no_change_btn] * 6 + ) + + +def clear_history(request: gr.Request): + logger.info(f"clear_history (named). ip: {get_ip(request)}") + return ( + [None] * num_sides + + [None] * num_sides + + [enable_multimodal] + + [invisible_btn] * 4 + + [disable_btn] * 2 + ) + + +def add_text( + state0, state1, model_selector0, model_selector1, chat_input, request: gr.Request +): + text, images = chat_input["text"], chat_input["files"] + ip = get_ip(request) + logger.info(f"add_text (named). ip: {ip}. len: {len(text)}") + states = [state0, state1] + model_selectors = [model_selector0, model_selector1] + + # Init states if necessary + for i in range(num_sides): + if states[i] is None: + states[i] = State(model_selectors[i], is_vision=True) + + if len(text) <= 0: + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [None] + + [ + no_change_btn, + ] + * 6 + ) + + model_list = [states[i].model_name for i in range(num_sides)] + all_conv_text_left = states[0].conv.get_prompt() + all_conv_text_right = states[0].conv.get_prompt() + all_conv_text = ( + all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text + ) + + images = convert_images_to_conversation_format(images) + + text, image_flagged, csam_flag = moderate_input( + state0, text, all_conv_text, model_list, images, ip + ) + + conv = states[0].conv + if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + logger.info(f"conversation turn limit. ip: {ip}. text: {text}") + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [{"text": CONVERSATION_LIMIT_MSG}] + + [ + no_change_btn, + ] + * 6 + ) + + if image_flagged: + logger.info(f"image flagged. ip: {ip}. text: {text}") + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [{"text": IMAGE_MODERATION_MSG}] + + [ + no_change_btn, + ] + * 6 + ) + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + for i in range(num_sides): + post_processed_text = _prepare_text_with_image( + states[i], text, images, csam_flag=csam_flag + ) + states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) + states[i].conv.append_message(states[i].conv.roles[1], None) + states[i].skip_next = False + + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [None] + + [ + disable_btn, + ] + * 6 + ) + + +def build_side_by_side_vision_ui_named(models, random_questions=None): + notice_markdown = f""" +# ⚔️ LMSYS Chatbot Arena (Multimodal): Benchmarking LLMs and VLMs in the Wild +[Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2403.04132) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) + +{SURVEY_LINK} + +## 📜 Rules +- Chat with any two models side-by-side and vote! +- You can continue chatting for multiple rounds. +- Click "Clear history" to start a new round. 
+- You can only chat with one image per conversation. You can upload images less than 15MB. Click the "Random Example" button to chat with a random image. + +**❗️ For research purposes, we log user prompts and images, and may release this data to the public in the future. Please do not upload any confidential or personal information.** + +## 🤖 Choose two models to compare +""" + + states = [gr.State() for _ in range(num_sides)] + model_selectors = [None] * num_sides + chatbots = [None] * num_sides + + notice = gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Row(): + with gr.Column(scale=2, visible=False) as image_column: + imagebox = gr.Image( + type="pil", + show_label=False, + interactive=False, + ) + + with gr.Column(scale=5): + with gr.Group(elem_id="share-region-anony"): + with gr.Accordion( + f"🔍 Expand to see the descriptions of {len(models)} models", + open=False, + ): + model_description_md = get_model_description_md(models) + gr.Markdown( + model_description_md, elem_id="model_description_markdown" + ) + + with gr.Row(): + for i in range(num_sides): + with gr.Column(): + model_selectors[i] = gr.Dropdown( + choices=models, + value=models[i] if len(models) > i else "", + interactive=True, + show_label=False, + container=False, + ) + + with gr.Row(): + for i in range(num_sides): + label = "Model A" if i == 0 else "Model B" + with gr.Column(): + chatbots[i] = gr.Chatbot( + label=label, + elem_id=f"chatbot", + height=650, + show_copy_button=True, + ) + + with gr.Row(): + leftvote_btn = gr.Button( + value="👈 A is better", visible=False, interactive=False + ) + rightvote_btn = gr.Button( + value="👉 B is better", visible=False, interactive=False + ) + tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False) + bothbad_btn = gr.Button( + value="👎 Both are bad", visible=False, interactive=False + ) + + with gr.Row(): + textbox = gr.MultimodalTextbox( + file_types=["image"], + show_label=False, + placeholder="Enter your prompt or add image here", + container=True, + elem_id="input_box", + ) + + with gr.Row() as button_row: + if random_questions: + global vqa_samples + with open(random_questions, "r") as f: + vqa_samples = json.load(f) + random_btn = gr.Button(value="🎲 Random Example", interactive=True) + clear_btn = gr.Button(value="🗑️ Clear history", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + share_btn = gr.Button(value="📷 Share") + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=2048, + value=1024, + step=64, + interactive=True, + label="Max output tokens", + ) + + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + + # Register listeners + btn_list = [ + leftvote_btn, + rightvote_btn, + tie_btn, + bothbad_btn, + regenerate_btn, + clear_btn, + ] + leftvote_btn.click( + leftvote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + rightvote_btn.click( + rightvote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + tie_btn.click( + tievote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + bothbad_btn.click( + 
bothbad_vote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + regenerate_btn.click( + regenerate, states, states + chatbots + [textbox] + btn_list + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + clear_btn.click(clear_history, None, states + chatbots + [textbox] + btn_list) + + share_js = """ +function (a, b, c, d) { + const captureElement = document.querySelector('#share-region-named'); + html2canvas(captureElement) + .then(canvas => { + canvas.style.display = 'none' + document.body.appendChild(canvas) + return canvas + }) + .then(canvas => { + const image = canvas.toDataURL('image/png') + const a = document.createElement('a') + a.setAttribute('download', 'chatbot-arena.png') + a.setAttribute('href', image) + a.click() + canvas.remove() + }); + return [a, b, c, d]; +} +""" + share_btn.click(share_click, states + model_selectors, [], js=share_js) + + for i in range(num_sides): + model_selectors[i].change( + clear_history, None, states + chatbots + [textbox] + btn_list + ).then(set_visible_image, [textbox], [image_column]) + + textbox.input(add_image, [textbox], [imagebox]).then( + set_visible_image, [textbox], [image_column] + ).then(clear_history_example, None, states + chatbots + [textbox] + btn_list) + + textbox.submit( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list, + ).then(set_invisible_image, [], [image_column]).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + + if random_questions: + random_btn.click( + get_vqa_sample, # First, get the VQA sample + [], # Pass the path to the VQA samples + [textbox, imagebox], # Outputs are textbox and imagebox + ).then(set_visible_image, [textbox], [image_column]).then( + clear_history_example, None, states + chatbots + [textbox] + btn_list + ) + + return states + model_selectors diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py new file mode 100644 index 0000000000000000000000000000000000000000..2ef47b14df24dcf14599d7122f5bb3428b251428 --- /dev/null +++ b/fastchat/serve/gradio_web_server.py @@ -0,0 +1,1063 @@ +""" +The gradio demo server for chatting with a single model. 
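+
+Usage (a typical local setup; assumed here, mirroring the vision demo module
+above -- the exact worker command depends on your deployment):
+python3 -m fastchat.serve.controller
+python3 -m fastchat.serve.model_worker --model-path <your-model-path>
+python3 -m fastchat.serve.gradio_web_server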
+""" + +import argparse +from collections import defaultdict +import datetime +import hashlib +import json +import os +import random +import time +import uuid + +import gradio as gr +import requests + +from fastchat.constants import ( + LOGDIR, + WORKER_API_TIMEOUT, + ErrorCode, + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + RATE_LIMIT_MSG, + SERVER_ERROR_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, + SESSION_EXPIRATION_TIME, + SURVEY_LINK, +) +from fastchat.model.model_adapter import ( + get_conversation_template, +) +from fastchat.model.model_registry import get_model_info, model_info +from fastchat.serve.api_provider import get_api_provider_stream_iter +from fastchat.serve.remote_logger import get_remote_logger +from fastchat.utils import ( + build_logger, + get_window_url_params_js, + get_window_url_params_with_tos_js, + moderation_filter, + parse_gradio_auth_creds, + load_image, +) + +logger = build_logger("gradio_web_server", "gradio_web_server.log") + +headers = {"User-Agent": "FastChat Client"} + +no_change_btn = gr.Button() +enable_btn = gr.Button(interactive=True, visible=True) +disable_btn = gr.Button(interactive=False) +invisible_btn = gr.Button(interactive=False, visible=False) +enable_text = gr.Textbox( + interactive=True, visible=True, placeholder="👉 Enter your prompt and press ENTER" +) +disable_text = gr.Textbox( + interactive=False, + visible=True, + placeholder='Press "🎲 New Round" to start over👇 (Note: Your vote shapes the leaderboard, please vote RESPONSIBLY!)', +) + +controller_url = None +enable_moderation = False +use_remote_storage = False + +acknowledgment_md = """ +### Terms of Service + +Users are required to agree to the following terms before using the service: + +The service is a research preview. It only provides limited safety measures and may generate offensive content. +It must not be used for any illegal, harmful, violent, racist, or sexual purposes. +Please do not upload any private information. +The service collects user dialogue data, including both text and images, and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) or a similar license. + +#### Please report any bug or issue to our [Discord](https://discord.gg/HSWAKCrnFx)/arena-feedback. + +### Acknowledgment +We thank [UC Berkeley SkyLab](https://sky.cs.berkeley.edu/), [Kaggle](https://www.kaggle.com/), [MBZUAI](https://mbzuai.ac.ae/), [a16z](https://www.a16z.com/), [Together AI](https://www.together.ai/), [Hyperbolic](https://hyperbolic.xyz/), [RunPod](https://runpod.io), [Anyscale](https://www.anyscale.com/), [HuggingFace](https://huggingface.co/) for their generous [sponsorship](https://lmsys.org/donations/). + + +""" + +# JSON file format of API-based models: +# { +# "gpt-3.5-turbo": { +# "model_name": "gpt-3.5-turbo", +# "api_type": "openai", +# "api_base": "https://api.openai.com/v1", +# "api_key": "sk-******", +# "anony_only": false +# } +# } +# +# - "api_type" can be one of the following: openai, anthropic, gemini, or mistral. For custom APIs, add a new type and implement it accordingly. +# - "anony_only" indicates whether to display this model in anonymous mode only. 
+ +api_endpoint_info = {} + + +class State: + def __init__(self, model_name, is_vision=False): + self.conv = get_conversation_template(model_name) + self.conv_id = uuid.uuid4().hex + self.skip_next = False + self.model_name = model_name + self.oai_thread_id = None + self.is_vision = is_vision + + # NOTE(chris): This could be sort of a hack since it assumes the user only uploads one image. If they can upload multiple, we should store a list of image hashes. + self.has_csam_image = False + + self.regen_support = True + if "browsing" in model_name: + self.regen_support = False + self.init_system_prompt(self.conv, is_vision) + + def init_system_prompt(self, conv, is_vision): + system_prompt = conv.get_system_message(is_vision) + if len(system_prompt) == 0: + return + current_date = datetime.datetime.now().strftime("%Y-%m-%d") + system_prompt = system_prompt.replace("{{currentDateTime}}", current_date) + conv.set_system_message(system_prompt) + + def to_gradio_chatbot(self): + return self.conv.to_gradio_chatbot() + + def dict(self): + base = self.conv.dict() + base.update( + { + "conv_id": self.conv_id, + "model_name": self.model_name, + } + ) + + if self.is_vision: + base.update({"has_csam_image": self.has_csam_image}) + return base + + +def set_global_vars(controller_url_, enable_moderation_, use_remote_storage_): + global controller_url, enable_moderation, use_remote_storage + controller_url = controller_url_ + enable_moderation = enable_moderation_ + use_remote_storage = use_remote_storage_ + + +def get_conv_log_filename(is_vision=False, has_csam_image=False): + t = datetime.datetime.now() + conv_log_filename = f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json" + if is_vision and not has_csam_image: + name = os.path.join(LOGDIR, f"vision-tmp-{conv_log_filename}") + elif is_vision and has_csam_image: + name = os.path.join(LOGDIR, f"vision-csam-{conv_log_filename}") + else: + name = os.path.join(LOGDIR, conv_log_filename) + + return name + + +def get_model_list(controller_url, register_api_endpoint_file, vision_arena): + global api_endpoint_info + + # Add models from the controller + if controller_url: + ret = requests.post(controller_url + "/refresh_all_workers") + assert ret.status_code == 200 + + if vision_arena: + ret = requests.post(controller_url + "/list_multimodal_models") + models = ret.json()["models"] + else: + ret = requests.post(controller_url + "/list_language_models") + models = ret.json()["models"] + else: + models = [] + + # Add models from the API providers + if register_api_endpoint_file: + api_endpoint_info = json.load(open(register_api_endpoint_file)) + for mdl, mdl_dict in api_endpoint_info.items(): + mdl_vision = mdl_dict.get("vision-arena", False) + mdl_text = mdl_dict.get("text-arena", True) + if vision_arena and mdl_vision: + models.append(mdl) + if not vision_arena and mdl_text: + models.append(mdl) + + # Remove anonymous models + models = list(set(models)) + visible_models = models.copy() + for mdl in models: + if mdl not in api_endpoint_info: + continue + mdl_dict = api_endpoint_info[mdl] + if mdl_dict["anony_only"]: + visible_models.remove(mdl) + + # Sort models and add descriptions + priority = {k: f"___{i:03d}" for i, k in enumerate(model_info)} + models.sort(key=lambda x: priority.get(x, x)) + visible_models.sort(key=lambda x: priority.get(x, x)) + logger.info(f"All models: {models}") + logger.info(f"Visible models: {visible_models}") + return visible_models, models + + +def load_demo_single(models, url_params): + selected_model = models[0] if len(models) > 0 
else "" + if "model" in url_params: + model = url_params["model"] + if model in models: + selected_model = model + + dropdown_update = gr.Dropdown(choices=models, value=selected_model, visible=True) + state = None + return state, dropdown_update + + +def load_demo(url_params, request: gr.Request): + global models + + ip = get_ip(request) + logger.info(f"load_demo. ip: {ip}. params: {url_params}") + + if args.model_list_mode == "reload": + models, all_models = get_model_list( + controller_url, args.register_api_endpoint_file, vision_arena=False + ) + + return load_demo_single(models, url_params) + + +def vote_last_response(state, vote_type, model_selector, request: gr.Request): + filename = get_conv_log_filename() + if "llava" in model_selector: + filename = filename.replace("2024", "vision-tmp-2024") + + with open(filename, "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "model": model_selector, + "state": state.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + get_remote_logger().log(data) + + +def upvote_last_response(state, model_selector, request: gr.Request): + ip = get_ip(request) + logger.info(f"upvote. ip: {ip}") + vote_last_response(state, "upvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def downvote_last_response(state, model_selector, request: gr.Request): + ip = get_ip(request) + logger.info(f"downvote. ip: {ip}") + vote_last_response(state, "downvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def flag_last_response(state, model_selector, request: gr.Request): + ip = get_ip(request) + logger.info(f"flag. ip: {ip}") + vote_last_response(state, "flag", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def regenerate(state, request: gr.Request): + ip = get_ip(request) + logger.info(f"regenerate. ip: {ip}") + if not state.regen_support: + state.skip_next = True + return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 + state.conv.update_last_message(None) + return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 + + +def clear_history(request: gr.Request): + ip = get_ip(request) + logger.info(f"clear_history. ip: {ip}") + state = None + return (state, [], "") + (disable_btn,) * 5 + + +def get_ip(request: gr.Request): + if "cf-connecting-ip" in request.headers: + ip = request.headers["cf-connecting-ip"] + elif "x-forwarded-for" in request.headers: + ip = request.headers["x-forwarded-for"] + if "," in ip: + ip = ip.split(",")[0] + else: + ip = request.client.host + return ip + + +def add_text(state, model_selector, text, request: gr.Request): + ip = get_ip(request) + logger.info(f"add_text. ip: {ip}. len: {len(text)}") + + if state is None: + state = State(model_selector) + + if len(text) <= 0: + state.skip_next = True + return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 + + all_conv_text = state.conv.get_prompt() + all_conv_text = all_conv_text[-2000:] + "\nuser: " + text + flagged = moderation_filter(all_conv_text, [state.model_name]) + # flagged = moderation_filter(text, [state.model_name]) + if flagged: + logger.info(f"violate moderation. ip: {ip}. text: {text}") + # overwrite the original text + text = MODERATION_MSG + + if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + logger.info(f"conversation turn limit. ip: {ip}. 
text: {text}") + state.skip_next = True + return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG, None) + ( + no_change_btn, + ) * 5 + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + state.conv.append_message(state.conv.roles[0], text) + state.conv.append_message(state.conv.roles[1], None) + return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 + + +def model_worker_stream_iter( + conv, + model_name, + worker_addr, + prompt, + temperature, + repetition_penalty, + top_p, + max_new_tokens, + images, +): + # Make requests + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, + "echo": False, + } + + logger.info(f"==== request ====\n{gen_params}") + + if len(images) > 0: + gen_params["images"] = images + + # Stream output + response = requests.post( + worker_addr + "/worker_generate_stream", + headers=headers, + json=gen_params, + stream=True, + timeout=WORKER_API_TIMEOUT, + ) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + yield data + + +def is_limit_reached(model_name, ip): + monitor_url = "http://localhost:9090" + try: + ret = requests.get( + f"{monitor_url}/is_limit_reached?model={model_name}&user_id={ip}", timeout=1 + ) + obj = ret.json() + return obj + except Exception as e: + logger.info(f"monitor error: {e}") + return None + + +def bot_response( + state, + temperature, + top_p, + max_new_tokens, + request: gr.Request, + apply_rate_limit=True, + use_recommended_config=False, +): + ip = get_ip(request) + logger.info(f"bot_response. ip: {ip}") + start_tstamp = time.time() + temperature = float(temperature) + top_p = float(top_p) + max_new_tokens = int(max_new_tokens) + + if state.skip_next: + # This generate call is skipped due to invalid inputs + state.skip_next = False + yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 + return + + if apply_rate_limit: + ret = is_limit_reached(state.model_name, ip) + if ret is not None and ret["is_limit_reached"]: + error_msg = RATE_LIMIT_MSG + "\n\n" + ret["reason"] + logger.info(f"rate limit reached. ip: {ip}. error_msg: {ret['reason']}") + state.conv.update_last_message(error_msg) + yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 + return + + conv, model_name = state.conv, state.model_name + model_api_dict = ( + api_endpoint_info[model_name] if model_name in api_endpoint_info else None + ) + images = conv.get_images() + + if model_api_dict is None: + # Query worker address + ret = requests.post( + controller_url + "/get_worker_address", json={"model": model_name} + ) + worker_addr = ret.json()["address"] + logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") + + # No available worker + if worker_addr == "": + conv.update_last_message(SERVER_ERROR_MSG) + yield ( + state, + state.to_gradio_chatbot(), + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) + return + + # Construct prompt. + # We need to call it here, so it will not be affected by "▌". 
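+ # (The "▌"/cursor placeholder is written into the conversation below before
+ # streaming starts, so the prompt text has to be captured first.)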
+ prompt = conv.get_prompt() + # Set repetition_penalty + if "t5" in model_name: + repetition_penalty = 1.2 + else: + repetition_penalty = 1.0 + + stream_iter = model_worker_stream_iter( + conv, + model_name, + worker_addr, + prompt, + temperature, + repetition_penalty, + top_p, + max_new_tokens, + images, + ) + else: + # Remove system prompt for API-based models unless specified + custom_system_prompt = model_api_dict.get("custom_system_prompt", False) + if not custom_system_prompt: + conv.set_system_message("") + + if use_recommended_config: + recommended_config = model_api_dict.get("recommended_config", None) + if recommended_config is not None: + temperature = recommended_config.get("temperature", temperature) + top_p = recommended_config.get("top_p", top_p) + max_new_tokens = recommended_config.get( + "max_new_tokens", max_new_tokens + ) + + stream_iter = get_api_provider_stream_iter( + conv, + model_name, + model_api_dict, + temperature, + top_p, + max_new_tokens, + state, + ) + + html_code = ' ' + + # conv.update_last_message("▌") + conv.update_last_message(html_code) + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + + try: + data = {"text": ""} + for i, data in enumerate(stream_iter): + if data["error_code"] == 0: + output = data["text"].strip() + conv.update_last_message(output + "▌") + # conv.update_last_message(output + html_code) + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + else: + output = data["text"] + f"\n\n(error_code: {data['error_code']})" + conv.update_last_message(output) + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) + return + output = data["text"].strip() + conv.update_last_message(output) + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + except requests.exceptions.RequestException as e: + conv.update_last_message( + f"{SERVER_ERROR_MSG}\n\n" + f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})" + ) + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) + return + except Exception as e: + conv.update_last_message( + f"{SERVER_ERROR_MSG}\n\n" + f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})" + ) + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) + return + + finish_tstamp = time.time() + logger.info(f"{output}") + + conv.save_new_images( + has_csam_images=state.has_csam_image, use_remote_storage=use_remote_storage + ) + + filename = get_conv_log_filename( + is_vision=state.is_vision, has_csam_image=state.has_csam_image + ) + + with open(filename, "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "gen_params": { + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + }, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + get_remote_logger().log(data) + + +block_css = """ +.prose { + font-size: 105% !important; +} + +#arena_leaderboard_dataframe table { + font-size: 105%; +} +#full_leaderboard_dataframe table { + font-size: 105%; +} + +.tab-nav button { + font-size: 18px; +} + +.chatbot h1 { + font-size: 130%; +} +.chatbot h2 { + font-size: 120%; +} +.chatbot h3 { + font-size: 110%; +} + +#chatbot .prose { + font-size: 90% !important; +} + +.sponsor-image-about img { + margin: 0 20px; + 
margin-top: 20px; + height: 40px; + max-height: 100%; + width: auto; + float: left; +} + +.cursor { + display: inline-block; + width: 7px; + height: 1em; + background-color: black; + vertical-align: middle; + animation: blink 1s infinite; +} + +.dark .cursor { + display: inline-block; + width: 7px; + height: 1em; + background-color: white; + vertical-align: middle; + animation: blink 1s infinite; +} + +@keyframes blink { + 0%, 50% { opacity: 1; } + 50.1%, 100% { opacity: 0; } +} + +.app { + max-width: 100% !important; + padding-left: 5% !important; + padding-right: 5% !important; +} + +a { + color: #1976D2; /* Your current link color, a shade of blue */ + text-decoration: none; /* Removes underline from links */ +} +a:hover { + color: #63A4FF; /* This can be any color you choose for hover */ + text-decoration: underline; /* Adds underline on hover */ +} +""" + + +# block_css = """ +# #notice_markdown .prose { +# font-size: 110% !important; +# } +# #notice_markdown th { +# display: none; +# } +# #notice_markdown td { +# padding-top: 6px; +# padding-bottom: 6px; +# } +# #arena_leaderboard_dataframe table { +# font-size: 110%; +# } +# #full_leaderboard_dataframe table { +# font-size: 110%; +# } +# #model_description_markdown { +# font-size: 110% !important; +# } +# #leaderboard_markdown .prose { +# font-size: 110% !important; +# } +# #leaderboard_markdown td { +# padding-top: 6px; +# padding-bottom: 6px; +# } +# #leaderboard_dataframe td { +# line-height: 0.1em; +# } +# #about_markdown .prose { +# font-size: 110% !important; +# } +# #ack_markdown .prose { +# font-size: 110% !important; +# } +# #chatbot .prose { +# font-size: 105% !important; +# } +# .sponsor-image-about img { +# margin: 0 20px; +# margin-top: 20px; +# height: 40px; +# max-height: 100%; +# width: auto; +# float: left; +# } + +# body { +# --body-text-size: 14px; +# } + +# .chatbot h1, h2, h3 { +# margin-top: 8px; /* Adjust the value as needed */ +# margin-bottom: 0px; /* Adjust the value as needed */ +# padding-bottom: 0px; +# } + +# .chatbot h1 { +# font-size: 130%; +# } +# .chatbot h2 { +# font-size: 120%; +# } +# .chatbot h3 { +# font-size: 110%; +# } +# .chatbot p:not(:first-child) { +# margin-top: 8px; +# } + +# .typing { +# display: inline-block; +# } + +# """ + + +def get_model_description_md(models): + model_description_md = """ +| | | | +| ---- | ---- | ---- | +""" + ct = 0 + visited = set() + for i, name in enumerate(models): + minfo = get_model_info(name) + if minfo.simple_name in visited: + continue + visited.add(minfo.simple_name) + one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}" + + if ct % 3 == 0: + model_description_md += "|" + model_description_md += f" {one_model_md} |" + if ct % 3 == 2: + model_description_md += "\n" + ct += 1 + return model_description_md + + +def build_about(): + about_markdown = """ +# About Us +Chatbot Arena is an open-source research project developed by members from [LMSYS](https://lmsys.org) and UC Berkeley [SkyLab](https://sky.cs.berkeley.edu/). Our mission is to build an open platform to evaluate LLMs by human preference in the real-world. +We open-source our [FastChat](https://github.com/lm-sys/FastChat) project at GitHub and release chat and human feedback dataset. We invite everyone to join us! 
+
+## Open-source contributors
+- [Wei-Lin Chiang](https://infwinston.github.io/), [Lianmin Zheng](https://lmzheng.net/), [Ying Sheng](https://sites.google.com/view/yingsheng/home), [Lisa Dunlap](https://www.lisabdunlap.com/), [Anastasios Angelopoulos](https://people.eecs.berkeley.edu/~angelopoulos/), [Christopher Chou](https://www.linkedin.com/in/chrisychou), [Tianle Li](https://codingwithtim.github.io/), [Siyuan Zhuang](https://www.linkedin.com/in/siyuanzhuang)
+- Advisors: [Ion Stoica](http://people.eecs.berkeley.edu/~istoica/), [Joseph E. Gonzalez](https://people.eecs.berkeley.edu/~jegonzal/), [Hao Zhang](https://cseweb.ucsd.edu/~haozhang/), [Trevor Darrell](https://people.eecs.berkeley.edu/~trevor/)
+
+## Learn more
+- Chatbot Arena [paper](https://arxiv.org/abs/2403.04132), [launch blog](https://lmsys.org/blog/2023-05-03-arena/), [dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md), [policy](https://lmsys.org/blog/2024-03-01-policy/)
+- LMSYS-Chat-1M dataset [paper](https://arxiv.org/abs/2309.11998), LLM Judge [paper](https://arxiv.org/abs/2306.05685)
+
+## Contact Us
+- Follow us on [X](https://x.com/lmsysorg) or [Discord](https://discord.gg/HSWAKCrnFx), or email us at lmsys.org@gmail.com
+- File issues on [GitHub](https://github.com/lm-sys/FastChat)
+- Download our datasets and models on [HuggingFace](https://huggingface.co/lmsys)
+
+## Acknowledgment
+We thank the [SkyPilot](https://github.com/skypilot-org/skypilot) and [Gradio](https://github.com/gradio-app/gradio) teams for their system support.
+We also thank [UC Berkeley SkyLab](https://sky.cs.berkeley.edu/), [Kaggle](https://www.kaggle.com/), [MBZUAI](https://mbzuai.ac.ae/), [a16z](https://www.a16z.com/), [Together AI](https://www.together.ai/), [Hyperbolic](https://hyperbolic.xyz/), [RunPod](https://runpod.io), [Anyscale](https://www.anyscale.com/), and [HuggingFace](https://huggingface.co/) for their generous sponsorship. Learn more about partnerships [here](https://lmsys.org/donations/).
+ + +""" + gr.Markdown(about_markdown, elem_id="about_markdown") + + +def build_single_model_ui(models, add_promotion_links=False): + promotion = ( + f""" +[Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2403.04132) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | [Kaggle Competition](https://www.kaggle.com/competitions/lmsys-chatbot-arena) + +{SURVEY_LINK} + +## 👇 Choose any model to chat +""" + if add_promotion_links + else "" + ) + + notice_markdown = f""" +# 🏔️ Chat with Large Language Models +{promotion} +""" + + state = gr.State() + gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Group(elem_id="share-region-named"): + with gr.Row(elem_id="model_selector_row"): + model_selector = gr.Dropdown( + choices=models, + value=models[0] if len(models) > 0 else "", + interactive=True, + show_label=False, + container=False, + ) + with gr.Row(): + with gr.Accordion( + f"🔍 Expand to see the descriptions of {len(models)} models", + open=False, + ): + model_description_md = get_model_description_md(models) + gr.Markdown(model_description_md, elem_id="model_description_markdown") + + chatbot = gr.Chatbot( + elem_id="chatbot", + label="Scroll down and start chatting", + height=650, + show_copy_button=True, + ) + with gr.Row(): + textbox = gr.Textbox( + show_label=False, + placeholder="👉 Enter your prompt and press ENTER", + elem_id="input_box", + ) + send_btn = gr.Button(value="Send", variant="primary", scale=0) + + with gr.Row() as button_row: + upvote_btn = gr.Button(value="👍 Upvote", interactive=False) + downvote_btn = gr.Button(value="👎 Downvote", interactive=False) + flag_btn = gr.Button(value="⚠️ Flag", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + clear_btn = gr.Button(value="🗑️ Clear history", interactive=False) + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=2048, + value=1024, + step=64, + interactive=True, + label="Max output tokens", + ) + + if add_promotion_links: + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + + # Register listeners + btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn] + upvote_btn.click( + upvote_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + downvote_btn.click( + downvote_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + flag_btn.click( + flag_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( + bot_response, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) + + model_selector.change(clear_history, None, [state, chatbot, textbox] + btn_list) + + textbox.submit( + add_text, + [state, model_selector, textbox], + [state, chatbot, textbox] + btn_list, + ).then( + bot_response, + [state, temperature, top_p, max_output_tokens], + 
[state, chatbot] + btn_list, + ) + send_btn.click( + add_text, + [state, model_selector, textbox], + [state, chatbot, textbox] + btn_list, + ).then( + bot_response, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + + return [state, model_selector] + + +def build_demo(models): + with gr.Blocks( + title="Chat with Open Large Language Models", + theme=gr.themes.Default(), + css=block_css, + ) as demo: + url_params = gr.JSON(visible=False) + + state, model_selector = build_single_model_ui(models) + + if args.model_list_mode not in ["once", "reload"]: + raise ValueError(f"Unknown model list mode: {args.model_list_mode}") + + if args.show_terms_of_use: + load_js = get_window_url_params_with_tos_js + else: + load_js = get_window_url_params_js + + demo.load( + load_demo, + [url_params], + [ + state, + model_selector, + ], + js=load_js, + ) + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument( + "--share", + action="store_true", + help="Whether to generate a public, shareable link", + ) + parser.add_argument( + "--controller-url", + type=str, + default="http://localhost:21001", + help="The address of the controller", + ) + parser.add_argument( + "--concurrency-count", + type=int, + default=10, + help="The concurrency count of the gradio queue", + ) + parser.add_argument( + "--model-list-mode", + type=str, + default="once", + choices=["once", "reload"], + help="Whether to load the model list once or reload the model list every time", + ) + parser.add_argument( + "--moderate", + action="store_true", + help="Enable content moderation to block unsafe inputs", + ) + parser.add_argument( + "--show-terms-of-use", + action="store_true", + help="Shows term of use before loading the demo", + ) + parser.add_argument( + "--register-api-endpoint-file", + type=str, + help="Register API-based model endpoints from a JSON file", + ) + parser.add_argument( + "--gradio-auth-path", + type=str, + help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', + ) + parser.add_argument( + "--gradio-root-path", + type=str, + help="Sets the gradio root path, eg /abc/def. 
Useful when running behind a reverse-proxy or at a custom URL path prefix", + ) + parser.add_argument( + "--use-remote-storage", + action="store_true", + default=False, + help="Uploads image files to google cloud storage if set to true", + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + # Set global variables + set_global_vars(args.controller_url, args.moderate, args.use_remote_storage) + models, all_models = get_model_list( + args.controller_url, args.register_api_endpoint_file, vision_arena=False + ) + + # Set authorization credentials + auth = None + if args.gradio_auth_path is not None: + auth = parse_gradio_auth_creds(args.gradio_auth_path) + + # Launch the demo + demo = build_demo(models) + demo.queue( + default_concurrency_limit=args.concurrency_count, + status_update_rate=10, + api_open=False, + ).launch( + server_name=args.host, + server_port=args.port, + share=args.share, + max_threads=200, + auth=auth, + root_path=args.gradio_root_path, + ) diff --git a/fastchat/serve/gradio_web_server_multi.py b/fastchat/serve/gradio_web_server_multi.py new file mode 100644 index 0000000000000000000000000000000000000000..14f254bf39177cdc3388434323d3dccec1d54f7f --- /dev/null +++ b/fastchat/serve/gradio_web_server_multi.py @@ -0,0 +1,313 @@ +""" +The gradio demo server with multiple tabs. +It supports chatting with a single model or chatting with two models side-by-side. +""" + +import argparse +import pickle +import time + +import gradio as gr + +from fastchat.serve.gradio_block_arena_anony import ( + build_side_by_side_ui_anony, + load_demo_side_by_side_anony, + set_global_vars_anony, +) +from fastchat.serve.gradio_block_arena_named import ( + build_side_by_side_ui_named, + load_demo_side_by_side_named, + set_global_vars_named, +) +from fastchat.serve.gradio_block_arena_vision import ( + build_single_vision_language_model_ui, +) +from fastchat.serve.gradio_block_arena_vision_anony import ( + build_side_by_side_vision_ui_anony, + load_demo_side_by_side_vision_anony, +) +from fastchat.serve.gradio_block_arena_vision_named import ( + build_side_by_side_vision_ui_named, +) + +from fastchat.serve.gradio_web_server import ( + set_global_vars, + block_css, + build_single_model_ui, + build_about, + get_model_list, + load_demo_single, + get_ip, +) +from fastchat.serve.monitor.monitor import build_leaderboard_tab +from fastchat.utils import ( + build_logger, + get_window_url_params_js, + get_window_url_params_with_tos_js, + alert_js, + parse_gradio_auth_creds, +) + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + + +def load_demo(url_params, request: gr.Request): + global models, all_models, vl_models, all_vl_models + + ip = get_ip(request) + logger.info(f"load_demo. ip: {ip}. 
params: {url_params}") + + inner_selected = 0 + if "arena" in url_params: + inner_selected = 0 + elif "vision" in url_params: + inner_selected = 1 + elif "compare" in url_params: + inner_selected = 1 + elif "direct" in url_params or "model" in url_params: + inner_selected = 3 + elif "leaderboard" in url_params: + inner_selected = 4 + elif "about" in url_params: + inner_selected = 5 + + if args.model_list_mode == "reload": + models, all_models = get_model_list( + args.controller_url, + args.register_api_endpoint_file, + vision_arena=False, + ) + + vl_models, all_vl_models = get_model_list( + args.controller_url, + args.register_api_endpoint_file, + vision_arena=True, + ) + + single_updates = load_demo_single(models, url_params) + side_by_side_anony_updates = load_demo_side_by_side_anony(all_models, url_params) + side_by_side_named_updates = load_demo_side_by_side_named(models, url_params) + + side_by_side_vision_anony_updates = load_demo_side_by_side_vision_anony( + all_models, all_vl_models, url_params + ) + + return ( + (gr.Tabs(selected=inner_selected),) + + single_updates + + side_by_side_anony_updates + + side_by_side_named_updates + + side_by_side_vision_anony_updates + ) + + +def build_demo(models, vl_models, elo_results_file, leaderboard_table_file): + if args.show_terms_of_use: + load_js = get_window_url_params_with_tos_js + else: + load_js = get_window_url_params_js + + head_js = """ + +""" + if args.ga_id is not None: + head_js += f""" + + + """ + text_size = gr.themes.sizes.text_lg + with gr.Blocks( + title="Chat with Open Large Language Models", + theme=gr.themes.Default(text_size=text_size), + css=block_css, + head=head_js, + ) as demo: + with gr.Tabs() as inner_tabs: + if args.vision_arena: + with gr.Tab("⚔️ Arena (battle)", id=0) as arena_tab: + arena_tab.select(None, None, None, js=load_js) + side_by_side_anony_list = build_side_by_side_vision_ui_anony( + all_models, + all_vl_models, + random_questions=args.random_questions, + ) + else: + with gr.Tab("⚔️ Arena (battle)", id=0) as arena_tab: + arena_tab.select(None, None, None, js=load_js) + side_by_side_anony_list = build_side_by_side_ui_anony(models) + + with gr.Tab("⚔️ Arena (side-by-side)", id=2) as side_by_side_tab: + side_by_side_tab.select(None, None, None, js=alert_js) + side_by_side_named_list = build_side_by_side_ui_named(models) + + with gr.Tab("💬 Direct Chat", id=3) as direct_tab: + direct_tab.select(None, None, None, js=alert_js) + single_model_list = build_single_model_ui( + models, add_promotion_links=True + ) + + demo_tabs = ( + [inner_tabs] + + single_model_list + + side_by_side_anony_list + + side_by_side_named_list + ) + + if elo_results_file: + with gr.Tab("🏆 Leaderboard", id=4): + build_leaderboard_tab( + elo_results_file, leaderboard_table_file, show_plot=True + ) + + with gr.Tab("ℹ️ About Us", id=5): + about = build_about() + + url_params = gr.JSON(visible=False) + + if args.model_list_mode not in ["once", "reload"]: + raise ValueError(f"Unknown model list mode: {args.model_list_mode}") + + demo.load( + load_demo, + [url_params], + demo_tabs, + js=load_js, + ) + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument( + "--share", + action="store_true", + help="Whether to generate a public, shareable link", + ) + parser.add_argument( + "--controller-url", + type=str, + default="http://localhost:21001", + help="The address of the controller", + ) + 
parser.add_argument( + "--concurrency-count", + type=int, + default=10, + help="The concurrency count of the gradio queue", + ) + parser.add_argument( + "--model-list-mode", + type=str, + default="once", + choices=["once", "reload"], + help="Whether to load the model list once or reload the model list every time.", + ) + parser.add_argument( + "--moderate", + action="store_true", + help="Enable content moderation to block unsafe inputs", + ) + parser.add_argument( + "--show-terms-of-use", + action="store_true", + help="Shows term of use before loading the demo", + ) + parser.add_argument( + "--vision-arena", action="store_true", help="Show tabs for vision arena." + ) + parser.add_argument( + "--random-questions", type=str, help="Load random questions from a JSON file" + ) + parser.add_argument( + "--register-api-endpoint-file", + type=str, + help="Register API-based model endpoints from a JSON file", + ) + parser.add_argument( + "--gradio-auth-path", + type=str, + help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', + default=None, + ) + parser.add_argument( + "--elo-results-file", type=str, help="Load leaderboard results and plots" + ) + parser.add_argument( + "--leaderboard-table-file", type=str, help="Load leaderboard results and plots" + ) + parser.add_argument( + "--gradio-root-path", + type=str, + help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix", + ) + parser.add_argument( + "--ga-id", + type=str, + help="the Google Analytics ID", + default=None, + ) + parser.add_argument( + "--use-remote-storage", + action="store_true", + default=False, + help="Uploads image files to google cloud storage if set to true", + ) + parser.add_argument( + "--password", + type=str, + help="Set the password for the gradio web server", + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + # Set global variables + set_global_vars(args.controller_url, args.moderate, args.use_remote_storage) + set_global_vars_named(args.moderate) + set_global_vars_anony(args.moderate) + models, all_models = get_model_list( + args.controller_url, + args.register_api_endpoint_file, + vision_arena=False, + ) + + vl_models, all_vl_models = get_model_list( + args.controller_url, + args.register_api_endpoint_file, + vision_arena=True, + ) + + # Set authorization credentials + auth = None + if args.gradio_auth_path is not None: + auth = parse_gradio_auth_creds(args.gradio_auth_path) + + # Launch the demo + demo = build_demo( + models, + all_vl_models, + args.elo_results_file, + args.leaderboard_table_file, + ) + demo.queue( + default_concurrency_limit=args.concurrency_count, + status_update_rate=10, + api_open=False, + ).launch( + server_name=args.host, + server_port=args.port, + share=args.share, + max_threads=200, + auth=auth, + root_path=args.gradio_root_path, + show_api=False, + ) diff --git a/fastchat/serve/huggingface_api.py b/fastchat/serve/huggingface_api.py new file mode 100644 index 0000000000000000000000000000000000000000..8022fbc93e9f2d4240eb67ff95061928cee81bbd --- /dev/null +++ b/fastchat/serve/huggingface_api.py @@ -0,0 +1,73 @@ +""" +Use FastChat with Hugging Face generation APIs. 
+ +Usage: +python3 -m fastchat.serve.huggingface_api --model lmsys/vicuna-7b-v1.5 +python3 -m fastchat.serve.huggingface_api --model lmsys/fastchat-t5-3b-v1.0 +""" +import argparse + +import torch + +from fastchat.model import load_model, get_conversation_template, add_model_args + + +@torch.inference_mode() +def main(args): + # Load model + model, tokenizer = load_model( + args.model_path, + device=args.device, + num_gpus=args.num_gpus, + max_gpu_memory=args.max_gpu_memory, + load_8bit=args.load_8bit, + cpu_offloading=args.cpu_offloading, + revision=args.revision, + debug=args.debug, + ) + + # Build the prompt with a conversation template + msg = args.message + conv = get_conversation_template(args.model_path) + conv.append_message(conv.roles[0], msg) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + # Run inference + inputs = tokenizer([prompt], return_tensors="pt").to(args.device) + output_ids = model.generate( + **inputs, + do_sample=True if args.temperature > 1e-5 else False, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + max_new_tokens=args.max_new_tokens, + ) + + if model.config.is_encoder_decoder: + output_ids = output_ids[0] + else: + output_ids = output_ids[0][len(inputs["input_ids"][0]) :] + outputs = tokenizer.decode( + output_ids, skip_special_tokens=True, spaces_between_special_tokens=False + ) + + # Print results + print(f"{conv.roles[0]}: {msg}") + print(f"{conv.roles[1]}: {outputs}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + add_model_args(parser) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--repetition_penalty", type=float, default=1.0) + parser.add_argument("--max-new-tokens", type=int, default=1024) + parser.add_argument("--debug", action="store_true") + parser.add_argument("--message", type=str, default="Hello! Who are you?") + args = parser.parse_args() + + # Reset default repetition penalty for T5 models. + if "t5" in args.model_path and args.repetition_penalty == 1.0: + args.repetition_penalty = 1.2 + + main(args) diff --git a/fastchat/serve/huggingface_api_worker.py b/fastchat/serve/huggingface_api_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..6ed8e6c8cc4a39904927532beab24874b7762a17 --- /dev/null +++ b/fastchat/serve/huggingface_api_worker.py @@ -0,0 +1,415 @@ +""" +A model worker that calls huggingface inference endpoint. + +Register models in a JSON file with the following format: +{ + "falcon-180b-chat": { + "model_name": "falcon-180B-chat", + "api_base": "https://api-inference.huggingface.co/models", + "model_path": "tiiuae/falcon-180B-chat", + "token": "hf_XXX", + "context_length": 2048 + }, + "zephyr-7b-beta": { + "model_name": "zephyr-7b-beta", + "model_path": "", + "api_base": "xxx", + "token": "hf_XXX", + "context_length": 4096 + } +} + +"model_path", "api_base", "token", and "context_length" are necessary, while others are optional. 
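+
+Example launch (the file name "hf_models.json" is only a placeholder for a file
+that follows the format above):
+python3 -m fastchat.serve.huggingface_api_worker --model-info-file hf_models.json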
+""" +import argparse +import asyncio +import json +import uuid +import os +from typing import List, Optional + +import requests +import uvicorn +from fastapi import BackgroundTasks, FastAPI, Request +from fastapi.responses import JSONResponse, StreamingResponse +from huggingface_hub import InferenceClient + +from fastchat.constants import SERVER_ERROR_MSG, ErrorCode +from fastchat.serve.base_model_worker import BaseModelWorker +from fastchat.utils import build_logger + +worker_id = str(uuid.uuid4())[:8] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") + +workers = [] +worker_map = {} +app = FastAPI() + + +# reference to +# https://github.com/philschmid/easyllm/blob/cbd908b3b3f44a97a22cb0fc2c93df3660bacdad/easyllm/clients/huggingface.py#L374-L392 +def get_gen_kwargs( + params, + seed: Optional[int] = None, +): + stop = params.get("stop", None) + if isinstance(stop, list): + stop_sequences = stop + elif isinstance(stop, str): + stop_sequences = [stop] + else: + stop_sequences = [] + gen_kwargs = { + "do_sample": True, + "return_full_text": bool(params.get("echo", False)), + "max_new_tokens": int(params.get("max_new_tokens", 256)), + "top_p": float(params.get("top_p", 1.0)), + "temperature": float(params.get("temperature", 1.0)), + "stop_sequences": stop_sequences, + "repetition_penalty": float(params.get("repetition_penalty", 1.0)), + "top_k": params.get("top_k", None), + "seed": seed, + } + if gen_kwargs["top_p"] == 1: + gen_kwargs["top_p"] = 0.9999999 + if gen_kwargs["top_p"] == 0: + gen_kwargs.pop("top_p") + if gen_kwargs["temperature"] == 0: + gen_kwargs.pop("temperature") + gen_kwargs["do_sample"] = False + return gen_kwargs + + +def could_be_stop(text, stop): + for s in stop: + if any(text.endswith(s[:i]) for i in range(1, len(s) + 1)): + return True + return False + + +class HuggingfaceApiWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + api_base: str, + token: str, + context_length: int, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + conv_template: Optional[str] = None, + seed: Optional[int] = None, + **kwargs, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template=conv_template, + ) + + self.model_path = model_path + self.api_base = api_base + self.token = token + self.context_len = context_length + self.seed = seed + + logger.info( + f"Connecting with huggingface api {self.model_path} as {self.model_names} on worker {worker_id} ..." 
+ ) + + if not no_register: + self.init_heart_beat() + + def count_token(self, params): + # No tokenizer here + ret = { + "count": 0, + "error_code": 0, + } + return ret + + def generate_stream_gate(self, params): + self.call_ct += 1 + + prompt = params["prompt"] + gen_kwargs = get_gen_kwargs(params, seed=self.seed) + stop = gen_kwargs["stop_sequences"] + if "falcon" in self.model_path and "chat" in self.model_path: + stop.extend(["\nUser:", "<|endoftext|>", " User:", "###"]) + stop = list(set(stop)) + gen_kwargs["stop_sequences"] = stop + + logger.info(f"prompt: {prompt}") + logger.info(f"gen_kwargs: {gen_kwargs}") + + try: + if self.model_path == "": + url = f"{self.api_base}" + else: + url = f"{self.api_base}/{self.model_path}" + client = InferenceClient(url, token=self.token) + res = client.text_generation( + prompt, stream=True, details=True, **gen_kwargs + ) + + reason = None + text = "" + for chunk in res: + if chunk.token.special: + continue + text += chunk.token.text + + s = next((x for x in stop if text.endswith(x)), None) + if s is not None: + text = text[: -len(s)] + reason = "stop" + break + if could_be_stop(text, stop): + continue + if ( + chunk.details is not None + and chunk.details.finish_reason is not None + ): + reason = chunk.details.finish_reason + if reason not in ["stop", "length"]: + reason = None + ret = { + "text": text, + "error_code": 0, + "finish_reason": reason, + } + yield json.dumps(ret).encode() + b"\0" + except Exception as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + yield json.dumps(ret).encode() + b"\0" + + def generate_gate(self, params): + for x in self.generate_stream_gate(params): + pass + return json.loads(x[:-1].decode()) + + def get_embeddings(self, params): + raise NotImplementedError() + + +def release_worker_semaphore(worker): + worker.semaphore.release() + + +def acquire_worker_semaphore(worker): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(worker): + background_tasks = BackgroundTasks() + background_tasks.add_task(lambda: release_worker_semaphore(worker)) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + await acquire_worker_semaphore(worker) + generator = worker.generate_stream_gate(params) + background_tasks = create_background_tasks(worker) + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + await acquire_worker_semaphore(worker) + output = worker.generate_gate(params) + release_worker_semaphore(worker) + return JSONResponse(output) + + +@app.post("/worker_get_embeddings") +async def api_get_embeddings(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + await acquire_worker_semaphore(worker) + embedding = worker.get_embeddings(params) + release_worker_semaphore(worker) + return JSONResponse(content=embedding) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return { + "model_names": [m for w in workers for m in w.model_names], + "speed": 1, + "queue_length": sum([w.get_queue_length() for w in workers]), + } + + +@app.post("/count_token") +async def api_count_token(request: Request): 
+ params = await request.json() + worker = worker_map[params["model"]] + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return {"context_length": worker.context_len} + + +def create_huggingface_api_worker(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + # all model-related parameters are listed in --model-info-file + parser.add_argument( + "--model-info-file", + type=str, + required=True, + help="Huggingface API model's info file path", + ) + + parser.add_argument( + "--limit-worker-concurrency", + type=int, + default=5, + help="Limit the model concurrency to prevent OOM.", + ) + parser.add_argument("--no-register", action="store_true") + parser.add_argument( + "--seed", + type=int, + default=None, + help="Overwrite the random seed for each generation.", + ) + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) + args = parser.parse_args() + + with open(args.model_info_file, "r", encoding="UTF-8") as f: + model_info = json.load(f) + + logger.info(f"args: {args}") + + model_path_list = [] + api_base_list = [] + token_list = [] + context_length_list = [] + model_names_list = [] + conv_template_list = [] + + for m in model_info: + model_path_list.append(model_info[m]["model_path"]) + api_base_list.append(model_info[m]["api_base"]) + token_list.append(model_info[m]["token"]) + + context_length = model_info[m]["context_length"] + model_names = model_info[m].get("model_names", [m.split("/")[-1]]) + if isinstance(model_names, str): + model_names = [model_names] + conv_template = model_info[m].get("conv_template", None) + + context_length_list.append(context_length) + model_names_list.append(model_names) + conv_template_list.append(conv_template) + + logger.info(f"Model paths: {model_path_list}") + logger.info(f"API bases: {api_base_list}") + logger.info(f"Tokens: {token_list}") + logger.info(f"Context lengths: {context_length_list}") + logger.info(f"Model names: {model_names_list}") + logger.info(f"Conv templates: {conv_template_list}") + + for ( + model_names, + conv_template, + model_path, + api_base, + token, + context_length, + ) in zip( + model_names_list, + conv_template_list, + model_path_list, + api_base_list, + token_list, + context_length_list, + ): + m = HuggingfaceApiWorker( + args.controller_address, + args.worker_address, + worker_id, + model_path, + api_base, + token, + context_length, + model_names, + args.limit_worker_concurrency, + no_register=args.no_register, + conv_template=conv_template, + seed=args.seed, + ) + workers.append(m) + for name in model_names: + worker_map[name] = m + + # register all the models + url = args.controller_address + "/register_worker" + data = { + "worker_name": workers[0].worker_addr, + "check_heart_beat": not args.no_register, + "worker_status": { + "model_names": [m for w in workers for m in 
w.model_names], + "speed": 1, + "queue_length": sum([w.get_queue_length() for w in workers]), + }, + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + return args, workers + + +if __name__ == "__main__": + args, workers = create_huggingface_api_worker() + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/fastchat/serve/inference.py b/fastchat/serve/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..6d155aab7c6cb32ca5fdb10d5661f1b298908f93 --- /dev/null +++ b/fastchat/serve/inference.py @@ -0,0 +1,555 @@ +"""Inference for FastChat models.""" +import abc +import gc +import json +import math +import os +import sys +import time +from typing import Iterable, Optional, Dict +import warnings + +import psutil +import torch +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LlamaTokenizer, + LlamaForCausalLM, + AutoModel, + AutoModelForSeq2SeqLM, + T5Tokenizer, + AutoConfig, +) +from transformers.generation.logits_process import ( + LogitsProcessorList, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +) + +from fastchat.conversation import get_conv_template, SeparatorStyle +from fastchat.model.model_adapter import ( + load_model, + get_conversation_template, + get_generate_stream_function, +) +from fastchat.modules.awq import AWQConfig +from fastchat.modules.gptq import GptqConfig +from fastchat.modules.exllama import ExllamaConfig +from fastchat.modules.xfastertransformer import XftConfig +from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length + + +def prepare_logits_processor( + temperature: float, repetition_penalty: float, top_p: float, top_k: int +) -> LogitsProcessorList: + processor_list = LogitsProcessorList() + # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases. + if temperature >= 1e-5 and temperature != 1.0: + processor_list.append(TemperatureLogitsWarper(temperature)) + if repetition_penalty > 1.0: + processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty)) + if 1e-8 <= top_p < 1.0: + processor_list.append(TopPLogitsWarper(top_p)) + if top_k > 0: + processor_list.append(TopKLogitsWarper(top_k)) + return processor_list + + +@torch.inference_mode() +def generate_stream( + model, + tokenizer, + params: Dict, + device: str, + context_len: int, + stream_interval: int = 2, + judge_sent_end: bool = False, +): + if hasattr(model, "device"): + device = model.device + + # Read parameters + prompt = params["prompt"] + len_prompt = len(prompt) + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = int(params.get("top_k", -1)) # -1 means disable + max_new_tokens = int(params.get("max_new_tokens", 256)) + logprobs = params.get("logprobs", None) # FIXME: Support logprobs>1. 
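+ # "echo" controls whether the prompt itself is included in the returned text.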
+ echo = bool(params.get("echo", True)) + stop_str = params.get("stop", None) + stop_token_ids = params.get("stop_token_ids", None) or [] + if tokenizer.eos_token_id not in stop_token_ids: + stop_token_ids.append(tokenizer.eos_token_id) + + logits_processor = prepare_logits_processor( + temperature, repetition_penalty, top_p, top_k + ) + input_ids = tokenizer(prompt).input_ids + + if model.config.is_encoder_decoder: + max_src_len = context_len + else: # truncate + max_src_len = context_len - max_new_tokens - 1 + + input_ids = input_ids[-max_src_len:] + output_ids = list(input_ids) + input_echo_len = len(input_ids) + + if model.config.is_encoder_decoder: + if logprobs is not None: # FIXME: Support logprobs for encoder-decoder models. + raise NotImplementedError + encoder_output = model.encoder( + input_ids=torch.as_tensor([input_ids], device=device) + )[0] + start_ids = torch.as_tensor( + [[model.generation_config.decoder_start_token_id]], + dtype=torch.int64, + device=device, + ) + else: + start_ids = torch.as_tensor([input_ids], device=device) + + past_key_values = out = None + token_logprobs = [None] # The first token has no logprobs. + sent_interrupt = False + finish_reason = None + stopped = False + for i in range(max_new_tokens): + if i == 0: # prefill + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=start_ids, + encoder_hidden_states=encoder_output, + use_cache=True, + ) + logits = model.lm_head(out[0]) + else: + out = model(input_ids=start_ids, use_cache=True) + logits = out.logits + past_key_values = out.past_key_values + + if logprobs is not None: + # Prefull logprobs for the prompt. + shift_input_ids = start_ids[..., 1:].contiguous() + shift_logits = logits[..., :-1, :].contiguous() + shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist() + for label_id, logit in zip( + shift_input_ids[0].tolist(), shift_logits[0] + ): + token_logprobs.append(logit[label_id]) + else: # decoding + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=torch.as_tensor( + [[token] if not sent_interrupt else output_ids], + device=device, + ), + encoder_hidden_states=encoder_output, + use_cache=True, + past_key_values=past_key_values if not sent_interrupt else None, + ) + sent_interrupt = False + + logits = model.lm_head(out[0]) + else: + out = model( + input_ids=torch.as_tensor( + [[token] if not sent_interrupt else output_ids], + device=device, + ), + use_cache=True, + past_key_values=past_key_values if not sent_interrupt else None, + ) + sent_interrupt = False + logits = out.logits + past_key_values = out.past_key_values + + if logits_processor: + if repetition_penalty > 1.0: + tmp_output_ids = torch.as_tensor([output_ids], device=logits.device) + else: + tmp_output_ids = None + last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0] + else: + last_token_logits = logits[0, -1, :] + + if device == "mps": + # Switch to CPU by avoiding some bugs in mps backend. + last_token_logits = last_token_logits.float().to("cpu") + + if temperature < 1e-5 or top_p < 1e-8: # greedy + _, indices = torch.topk(last_token_logits, 2) + tokens = [int(index) for index in indices.tolist()] + else: + probs = torch.softmax(last_token_logits, dim=-1) + indices = torch.multinomial(probs, num_samples=2) + tokens = [int(token) for token in indices.tolist()] + token = tokens[0] + output_ids.append(token) + if logprobs is not None: + # Cannot use last_token_logits because logprobs is based on raw logits. 
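+ # (last_token_logits may already be warped by the temperature/top-p/top-k
+ # processors above, so the untouched logits[0, -1, :] are used here instead.)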
+ token_logprobs.append( + torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist() + ) + + if token in stop_token_ids: + stopped = True + else: + stopped = False + + # Yield the output tokens + if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped: + if echo: + tmp_output_ids = output_ids + rfind_start = len_prompt + else: + tmp_output_ids = output_ids[input_echo_len:] + rfind_start = 0 + + output = tokenizer.decode( + tmp_output_ids, + skip_special_tokens=True, + spaces_between_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + ret_logprobs = None + if logprobs is not None: + ret_logprobs = { + "text_offset": [], + "tokens": [ + tokenizer.decode(token) + for token in ( + output_ids if echo else output_ids[input_echo_len:] + ) + ], + "token_logprobs": token_logprobs + if echo + else token_logprobs[input_echo_len:], + "top_logprobs": [{}] + * len(token_logprobs if echo else token_logprobs[input_echo_len:]), + } + # Compute text_offset + curr_pos = 0 + for text in ret_logprobs["tokens"]: + ret_logprobs["text_offset"].append(curr_pos) + curr_pos += len(text) + + # TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way + if judge_sent_end and stopped and not is_sentence_complete(output): + if len(tokens) > 1: + token = tokens[1] + output_ids[-1] = token + else: + output_ids.pop() + stopped = False + sent_interrupt = True + + partially_stopped = False + if stop_str: + if isinstance(stop_str, str): + pos = output.rfind(stop_str, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + else: + partially_stopped = is_partial_stop(output, stop_str) + elif isinstance(stop_str, Iterable): + for each_stop in stop_str: + pos = output.rfind(each_stop, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + break + else: + partially_stopped = is_partial_stop(output, each_stop) + if partially_stopped: + break + else: + raise ValueError("Invalid stop field type.") + + # Prevent yielding partial stop sequence + if not partially_stopped: + yield { + "text": output, + "logprobs": ret_logprobs, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + + if stopped: + break + + # Finish stream event, which contains finish reason + else: + finish_reason = "length" + + if stopped: + finish_reason = "stop" + + yield { + "text": output, + "logprobs": ret_logprobs, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + + # Clean + del past_key_values, out + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() + + +class ChatIO(abc.ABC): + @abc.abstractmethod + def prompt_for_input(self, role: str) -> str: + """Prompt for input from a role.""" + + @abc.abstractmethod + def prompt_for_output(self, role: str): + """Prompt for output from a role.""" + + @abc.abstractmethod + def stream_output(self, output_stream): + """Stream output.""" + + @abc.abstractmethod + def print_output(self, text: str): + """Print output.""" + + +def chat_loop( + model_path: str, + device: str, + num_gpus: int, + max_gpu_memory: str, + dtype: Optional[torch.dtype], + load_8bit: bool, + cpu_offloading: bool, + conv_template: Optional[str], + conv_system_msg: Optional[str], + temperature: float, + repetition_penalty: float, + max_new_tokens: 
int, + chatio: ChatIO, + gptq_config: Optional[GptqConfig] = None, + awq_config: Optional[AWQConfig] = None, + exllama_config: Optional[ExllamaConfig] = None, + xft_config: Optional[XftConfig] = None, + revision: str = "main", + judge_sent_end: bool = True, + debug: bool = True, + history: bool = True, +): + # Model + model, tokenizer = load_model( + model_path, + device=device, + num_gpus=num_gpus, + max_gpu_memory=max_gpu_memory, + dtype=dtype, + load_8bit=load_8bit, + cpu_offloading=cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + exllama_config=exllama_config, + xft_config=xft_config, + revision=revision, + debug=debug, + ) + generate_stream_func = get_generate_stream_function(model, model_path) + + model_type = str(type(model)).lower() + is_t5 = "t5" in model_type + is_codet5p = "codet5p" in model_type + is_xft = "xft" in model_type + + # Hardcode T5's default repetition penalty to be 1.2 + if is_t5 and repetition_penalty == 1.0: + repetition_penalty = 1.2 + + # Set context length + context_len = get_context_length(model.config) + + # Chat + def new_chat(): + if conv_template: + conv = get_conv_template(conv_template) + else: + conv = get_conversation_template(model_path) + if conv_system_msg is not None: + conv.set_system_message(conv_system_msg) + return conv + + def reload_conv(conv): + """ + Reprints the conversation from the start. + """ + for message in conv.messages[conv.offset :]: + chatio.prompt_for_output(message[0]) + chatio.print_output(message[1]) + + conv = None + + while True: + if not history or not conv: + conv = new_chat() + + try: + inp = chatio.prompt_for_input(conv.roles[0]) + except EOFError: + inp = "" + + if inp == "!!exit" or not inp: + print("exit...") + break + elif inp == "!!reset": + print("resetting...") + conv = new_chat() + continue + elif inp == "!!remove": + print("removing last message...") + if len(conv.messages) > conv.offset: + # Assistant + if conv.messages[-1][0] == conv.roles[1]: + conv.messages.pop() + # User + if conv.messages[-1][0] == conv.roles[0]: + conv.messages.pop() + reload_conv(conv) + else: + print("No messages to remove.") + continue + elif inp == "!!regen": + print("regenerating last message...") + if len(conv.messages) > conv.offset: + # Assistant + if conv.messages[-1][0] == conv.roles[1]: + conv.messages.pop() + # User + if conv.messages[-1][0] == conv.roles[0]: + reload_conv(conv) + # Set inp to previous message + inp = conv.messages.pop()[1] + else: + # Shouldn't happen in normal circumstances + print("No user message to regenerate from.") + continue + else: + print("No messages to regenerate.") + continue + elif inp.startswith("!!save"): + args = inp.split(" ", 1) + + if len(args) != 2: + print("usage: !!save ") + continue + else: + filename = args[1] + + # Add .json if extension not present + if not "." 
in filename: + filename += ".json" + + print("saving...", filename) + with open(filename, "w") as outfile: + json.dump(conv.dict(), outfile) + continue + elif inp.startswith("!!load"): + args = inp.split(" ", 1) + + if len(args) != 2: + print("usage: !!load ") + continue + else: + filename = args[1] + + # Check if file exists and add .json if needed + if not os.path.exists(filename): + if (not filename.endswith(".json")) and os.path.exists( + filename + ".json" + ): + filename += ".json" + else: + print("file not found:", filename) + continue + + print("loading...", filename) + with open(filename, "r") as infile: + new_conv = json.load(infile) + + conv = get_conv_template(new_conv["template_name"]) + conv.set_system_message(new_conv["system_message"]) + conv.messages = new_conv["messages"] + reload_conv(conv) + continue + + conv.append_message(conv.roles[0], inp) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + if is_codet5p: # codet5p is a code completion model. + prompt = inp + + gen_params = { + "model": model_path, + "prompt": prompt, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + "max_new_tokens": max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, + "echo": False, + } + + try: + chatio.prompt_for_output(conv.roles[1]) + output_stream = generate_stream_func( + model, + tokenizer, + gen_params, + device, + context_len=context_len, + judge_sent_end=judge_sent_end, + ) + t = time.time() + outputs = chatio.stream_output(output_stream) + duration = time.time() - t + conv.update_last_message(outputs.strip()) + + if debug: + num_tokens = len(tokenizer.encode(outputs)) + msg = { + "conv_template": conv.name, + "prompt": prompt, + "outputs": outputs, + "speed (token/s)": round(num_tokens / duration, 2), + } + print(f"\n{msg}\n") + + except KeyboardInterrupt: + print("stopped generation.") + # If generation didn't finish + if conv.messages[-1][1] is None: + conv.messages.pop() + # Remove last user message, so there isn't a double up + if conv.messages[-1][0] == conv.roles[0]: + conv.messages.pop() + + reload_conv(conv) diff --git a/fastchat/serve/launch_all_serve.py b/fastchat/serve/launch_all_serve.py new file mode 100644 index 0000000000000000000000000000000000000000..2f4ad7b0b134d1699ff8ba0d95d8039ec3c1f204 --- /dev/null +++ b/fastchat/serve/launch_all_serve.py @@ -0,0 +1,284 @@ +""" +Usage: python launch_all_serve_by_shell.py --model-path-address "THUDM/chatglm2-6b@localhost@2021" "huggyllama/llama-7b@localhost@2022" + +Workers are listed in format of `model-path`@`host`@`port` + +The key mechanism behind this scripts is: + 1, execute shell cmd to launch the controller/worker/openai-api-server; + 2, check the log of controller/worker/openai-api-server to ensure that the serve is launched properly. +Note that a few of non-critical `fastchat.serve` cmd options are not supported currently. 
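+
+Each component is started with a command of the shape
+`nohup python3 -m fastchat.serve.<component> <options> > <LOGDIR>/<name>.log 2>&1 &`
+(see `base_launch_sh` below), and its log file is then checked to confirm a successful launch.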
+""" +import sys +import os + +sys.path.append(os.path.dirname(os.path.dirname(__file__))) + +import subprocess +import re +import argparse + +LOGDIR = "./logs/" + +if not os.path.exists(LOGDIR): + os.makedirs(LOGDIR) + +parser = argparse.ArgumentParser() +# ------multi worker----------------- +parser.add_argument( + "--model-path-address", + default="THUDM/chatglm2-6b@localhost@20002", + nargs="+", + type=str, + help="model path, host, and port, formatted as model-path@host@port", +) +# ---------------controller------------------------- + +parser.add_argument("--controller-host", type=str, default="localhost") +parser.add_argument("--controller-port", type=int, default=21001) +parser.add_argument( + "--dispatch-method", + type=str, + choices=["lottery", "shortest_queue"], + default="shortest_queue", +) +controller_args = ["controller-host", "controller-port", "dispatch-method"] + +# ----------------------worker------------------------------------------ + +parser.add_argument("--worker-host", type=str, default="localhost") +parser.add_argument("--worker-port", type=int, default=21002) +# parser.add_argument("--worker-address", type=str, default="http://localhost:21002") +# parser.add_argument( +# "--controller-address", type=str, default="http://localhost:21001" +# ) +parser.add_argument( + "--model-path", + type=str, + default="lmsys/vicuna-7b-v1.5", + help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", +) +parser.add_argument( + "--revision", + type=str, + default="main", + help="Hugging Face Hub model revision identifier", +) +parser.add_argument( + "--device", + type=str, + choices=["cpu", "cuda", "mps", "xpu", "npu"], + default="cuda", + help="The device type", +) +parser.add_argument( + "--gpus", + type=str, + default="0", + help="A single GPU like 1 or multiple GPUs like 0,2", +) +parser.add_argument("--num-gpus", type=int, default=1) +parser.add_argument( + "--max-gpu-memory", + type=str, + help="The maximum memory per gpu. Use a string like '13Gib'", +) +parser.add_argument("--load-8bit", action="store_true", help="Use 8-bit quantization") +parser.add_argument( + "--cpu-offloading", + action="store_true", + help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU", +) +parser.add_argument( + "--gptq-ckpt", + type=str, + default=None, + help="Load quantized model. 
The path to the local GPTQ checkpoint.", +) +parser.add_argument( + "--gptq-wbits", + type=int, + default=16, + choices=[2, 3, 4, 8, 16], + help="#bits to use for quantization", +) +parser.add_argument( + "--gptq-groupsize", + type=int, + default=-1, + help="Groupsize to use for quantization; default uses full row.", +) +parser.add_argument( + "--gptq-act-order", + action="store_true", + help="Whether to apply the activation order GPTQ heuristic", +) +parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", +) +parser.add_argument( + "--limit-worker-concurrency", + type=int, + default=5, + help="Limit the model concurrency to prevent OOM.", +) +parser.add_argument("--stream-interval", type=int, default=2) +parser.add_argument("--no-register", action="store_true") + +worker_args = [ + "worker-host", + "worker-port", + "model-path", + "revision", + "device", + "gpus", + "num-gpus", + "max-gpu-memory", + "load-8bit", + "cpu-offloading", + "gptq-ckpt", + "gptq-wbits", + "gptq-groupsize", + "gptq-act-order", + "model-names", + "limit-worker-concurrency", + "stream-interval", + "no-register", + "controller-address", +] +# -----------------openai server--------------------------- + +parser.add_argument("--server-host", type=str, default="localhost", help="host name") +parser.add_argument("--server-port", type=int, default=8001, help="port number") +parser.add_argument( + "--allow-credentials", action="store_true", help="allow credentials" +) +# parser.add_argument( +# "--allowed-origins", type=json.loads, default=["*"], help="allowed origins" +# ) +# parser.add_argument( +# "--allowed-methods", type=json.loads, default=["*"], help="allowed methods" +# ) +# parser.add_argument( +# "--allowed-headers", type=json.loads, default=["*"], help="allowed headers" +# ) +parser.add_argument( + "--api-keys", + type=lambda s: s.split(","), + help="Optional list of comma separated API keys", +) +server_args = [ + "server-host", + "server-port", + "allow-credentials", + "api-keys", + "controller-address", +] + +args = parser.parse_args() + +args = argparse.Namespace( + **vars(args), + **{"controller-address": f"http://{args.controller_host}:{args.controller_port}"}, +) + +if args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + +# 0,controller, model_worker, openai_api_server +# 1, cmd options +# 2,LOGDIR +# 3, log file name +base_launch_sh = "nohup python3 -m fastchat.serve.{0} {1} >{2}/{3}.log 2>&1 &" + +# 0 LOGDIR +#! 
1 log file name +# 2 controller, worker, openai_api_server +base_check_sh = """while [ `grep -c "Uvicorn running on" {0}/{1}.log` -eq '0' ];do + sleep 1s; + echo "wait {2} running" + done + echo '{2} running' """ + + +def string_args(args, args_list): + args_str = "" + for key, value in args._get_kwargs(): + key = key.replace("_", "-") + if key not in args_list: + continue + + key = key.split("-")[-1] if re.search("port|host", key) else key + if not value: + pass + # 1==True -> True + elif isinstance(value, bool) and value == True: + args_str += f" --{key} " + elif ( + isinstance(value, list) + or isinstance(value, tuple) + or isinstance(value, set) + ): + value = " ".join(value) + args_str += f" --{key} {value} " + else: + args_str += f" --{key} {value} " + + return args_str + + +def launch_worker(item): + log_name = ( + item.split("/")[-1] + .split("\\")[-1] + .replace("-", "_") + .replace("@", "_") + .replace(".", "_") + ) + + args.model_path, args.worker_host, args.worker_port = item.split("@") + print("*" * 80) + worker_str_args = string_args(args, worker_args) + print(worker_str_args) + worker_sh = base_launch_sh.format( + "model_worker", worker_str_args, LOGDIR, f"worker_{log_name}" + ) + worker_check_sh = base_check_sh.format(LOGDIR, f"worker_{log_name}", "model_worker") + subprocess.run(worker_sh, shell=True, check=True) + subprocess.run(worker_check_sh, shell=True, check=True) + + +def launch_all(): + controller_str_args = string_args(args, controller_args) + controller_sh = base_launch_sh.format( + "controller", controller_str_args, LOGDIR, "controller" + ) + controller_check_sh = base_check_sh.format(LOGDIR, "controller", "controller") + subprocess.run(controller_sh, shell=True, check=True) + subprocess.run(controller_check_sh, shell=True, check=True) + + if isinstance(args.model_path_address, str): + launch_worker(args.model_path_address) + else: + for idx, item in enumerate(args.model_path_address): + print(f"loading {idx}th model:{item}") + launch_worker(item) + + server_str_args = string_args(args, server_args) + server_sh = base_launch_sh.format( + "openai_api_server", server_str_args, LOGDIR, "openai_api_server" + ) + server_check_sh = base_check_sh.format( + LOGDIR, "openai_api_server", "openai_api_server" + ) + subprocess.run(server_sh, shell=True, check=True) + subprocess.run(server_check_sh, shell=True, check=True) + + +if __name__ == "__main__": + launch_all() diff --git a/fastchat/serve/lightllm_worker.py b/fastchat/serve/lightllm_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..ed0e21b68e3a9c03556937987910b590344b452f --- /dev/null +++ b/fastchat/serve/lightllm_worker.py @@ -0,0 +1,512 @@ +""" +A model worker that executes the model based on LightLLM. 
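+
+A typical launch looks like the following (illustrative only; the weight
+directory is a placeholder and 8000 is simply this file's default --port):
+
+    python3 -m fastchat.serve.lightllm_worker --model-path /path/to/weights --port 8000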
+ +See documentations at docs/lightllm_integration.md +""" + +import argparse +import asyncio +import json +import os +import torch +import uvicorn + +from transformers import AutoConfig + +from typing import List + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse + +from fastchat.serve.base_model_worker import BaseModelWorker +from fastchat.serve.model_worker import ( + logger, + worker_id, +) + +from lightllm.server.sampling_params import SamplingParams +from lightllm.server.multimodal_params import MultimodalParams +from lightllm.server.httpserver.manager import HttpServerManager +from lightllm.server.detokenization.manager import start_detokenization_process +from lightllm.server.router.manager import start_router_process +from lightllm.server.req_id_generator import ReqIDGenerator + +from lightllm.utils.net_utils import alloc_can_use_network_port +from lightllm.utils.start_utils import start_submodule_processes +from fastchat.utils import get_context_length, is_partial_stop + +app = FastAPI() +g_id_gen = ReqIDGenerator() + + +class LightLLMWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + conv_template: str, + tokenizer, + context_len, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template, + ) + + logger.info( + f"Loading the model {self.model_names} on worker {worker_id}, worker type: LightLLM worker..." + ) + self.tokenizer = tokenizer + self.context_len = context_len + + self.is_first = True + + if not no_register: + self.init_heart_beat() + + async def generate_stream(self, params): + self.call_ct += 1 + + prompt = params.pop("prompt") + request_id = params.pop("request_id") + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = params.get("top_k", -1.0) + presence_penalty = float(params.get("presence_penalty", 0.0)) + frequency_penalty = float(params.get("frequency_penalty", 0.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + max_new_tokens = params.get("max_new_tokens", 256) + echo = params.get("echo", True) + stop_str = params.get("stop", None) + stop_token_ids = params.get("stop_token_ids", None) or [] + if self.tokenizer.eos_token_id is not None: + stop_token_ids.append(self.tokenizer.eos_token_id) + + request = params.get("request", None) + + # Handle stop_str + stop = set() + if isinstance(stop_str, str) and stop_str != "": + stop.add(stop_str) + elif isinstance(stop_str, list) and stop_str != []: + stop.update(stop_str) + + for tid in stop_token_ids: + if tid is not None: + s = self.tokenizer.decode(tid) + if s != "": + stop.add(s) + + if self.is_first: + loop = asyncio.get_event_loop() + loop.create_task(httpserver_manager.handle_loop()) + self.is_first = False + + # make sampling params in vllm + top_p = max(top_p, 1e-5) + if temperature <= 1e-5: + top_p = 1.0 + + sampling_params = SamplingParams( + do_sample=temperature > 0.0, + temperature=temperature, + top_p=top_p, + top_k=top_k, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + repetition_penalty=repetition_penalty, + max_new_tokens=max_new_tokens, + stop_sequences=list(stop), + ) + sampling_params.verify() + + results_generator = httpserver_manager.generate( + prompt, sampling_params, request_id, 
MultimodalParams() + ) + + completion_tokens = 0 + text_outputs = "" + cumulative_logprob = 0.0 + + async for request_output, metadata, finish_status in results_generator: + text_outputs += request_output + completion_tokens += 1 + + partial_stop = any(is_partial_stop(text_outputs, i) for i in stop) + # prevent yielding partial stop sequence + if partial_stop: + continue + + if type(finish_status) is bool: # compatibility with old version + finish_reason = "stop" if finish_status else None + else: + finish_reason = finish_status.get_finish_reason() + + if request and await request.is_disconnected(): + await httpserver_manager.abort(request_id) + finish_reason = "abort" + + logprob = metadata.get("logprob", None) + if logprob is not None: + cumulative_logprob += logprob + + prompt_tokens = metadata["prompt_tokens"] + ret = { + "text": prompt + text_outputs if echo else text_outputs, + "error_code": 0, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + "cumulative_logprob": cumulative_logprob, + } + + if finish_reason is not None: + yield ( + json.dumps({**ret, "finish_reason": None}, ensure_ascii=False) + + "\0" + ).encode("utf-8") + yield ( + json.dumps({**ret, "finish_reason": finish_reason}, ensure_ascii=False) + + "\0" + ).encode("utf-8") + + if finish_reason is not None: # In case of abort, we need to break the loop + break + + async def generate(self, params): + async for x in self.generate_stream(params): + pass + return json.loads(x[:-1].decode()) + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(request_id): + async def abort_request() -> None: + await httpserver_manager.abort(request_id) + + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + background_tasks.add_task(abort_request) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + request_id = g_id_gen.generate_id() + params["request_id"] = request_id + params["request"] = request + generator = worker.generate_stream(params) + background_tasks = create_background_tasks(request_id) + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + request_id = g_id_gen.generate_id() + params["request_id"] = request_id + params["request"] = request + output = await worker.generate(params) + release_worker_semaphore() + await httpserver_manager.abort(request_id) + return JSONResponse(output) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} + + +if __name__ == "__main__": + torch.multiprocessing.set_start_method("spawn") + parser = 
argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + + parser.add_argument( + "--model-path", + dest="model_dir", + type=str, + default=None, + help="the model weight dir path, the app will load config, weights and tokenizer from this dir", + ) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + parser.add_argument("--limit-worker-concurrency", type=int, default=1024) + parser.add_argument("--no-register", action="store_true") + + parser.add_argument( + "--tokenizer_mode", + type=str, + default="slow", + help="""tokenizer load mode, can be slow or auto, slow mode load fast but run slow, slow mode is good for debug and test, + when you want to get best performance, try auto mode""", + ) + parser.add_argument( + "--load_way", + type=str, + default="HF", + help="the way of loading model weights, the default is HF(Huggingface format), llama also supports DS(Deepspeed)", + ) + parser.add_argument( + "--max_total_token_num", + type=int, + default=6000, + help="the total token nums the gpu and model can support, equals = max_batch * (input_len + output_len)", + ) + parser.add_argument( + "--batch_max_tokens", + type=int, + default=None, + help="max tokens num for new cat batch, it control prefill batch size to Preventing OOM", + ) + parser.add_argument("--eos_id", type=int, default=2, help="eos stop token id") + parser.add_argument( + "--running_max_req_size", + type=int, + default=1000, + help="the max size for forward requests in the same time", + ) + parser.add_argument( + "--tp", type=int, default=1, help="model tp parral size, the default is 1" + ) + parser.add_argument( + "--max_req_input_len", + type=int, + default=None, + help="the max value for req input tokens num. If None, it will be derived from the config.", + ) + parser.add_argument( + "--max_req_total_len", + type=int, + default=None, + help="the max value for req_input_len + req_output_len. 
If None, it will be derived from the config.", + ) + parser.add_argument( + "--mode", + type=str, + default=[], + nargs="+", + help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding + | triton_gqa_attention | triton_gqa_flashdecoding] + [triton_int8weight | triton_int4weight | lmdeploy_int4weight | ppl_int4weight], + triton_flashdecoding mode is for long context, current support llama llama2 qwen; + triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA; + triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel; + ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel; + ppl_fp16 mode use ppl fast fp16 decode attention kernel; + triton_int8weight and triton_int4weight and lmdeploy_int4weight or ppl_int4weight mode use int8 and int4 to store weights; + you need to read source code to make sure the supported detail mode for all models""", + ) + parser.add_argument( + "--trust_remote_code", + action="store_true", + help="Whether or not to allow for custom models defined on the Hub in their own modeling files.", + ) + parser.add_argument( + "--disable_log_stats", + action="store_true", + help="disable logging throughput stats.", + ) + parser.add_argument( + "--log_stats_interval", + type=int, + default=10, + help="log stats interval in second.", + ) + + parser.add_argument( + "--router_token_ratio", + type=float, + default=0.0, + help="token ratio to control router dispatch", + ) + parser.add_argument( + "--router_max_new_token_len", + type=int, + default=1024, + help="the request max new token len for router", + ) + + parser.add_argument( + "--no_skipping_special_tokens", + action="store_true", + help="whether to skip special tokens when decoding", + ) + parser.add_argument( + "--no_spaces_between_special_tokens", + action="store_true", + help="whether to add spaces between special tokens when decoding", + ) + + parser.add_argument( + "--splitfuse_mode", action="store_true", help="use splitfuse mode" + ) + parser.add_argument( + "--splitfuse_block_size", type=int, default=256, help="splitfuse block size" + ) + parser.add_argument( + "--prompt_cache_strs", + type=str, + default=[], + nargs="+", + help="""prompt cache strs""", + ) + parser.add_argument( + "--cache_capacity", + type=int, + default=200, + help="cache server capacity for multimodal resources", + ) + parser.add_argument( + "--cache_reserved_ratio", + type=float, + default=0.5, + help="cache server reserved capacity ratio after clear", + ) + parser.add_argument( + "--return_all_prompt_logprobs", + action="store_true", + help="return all prompt tokens logprobs", + ) + parser.add_argument( + "--long_truncation_mode", + type=str, + choices=[None, "head", "center"], + default=None, + help="""use to select the handle way when input token len > max_req_input_len. 
+ None : raise Exception + head : remove some head tokens to make input token len <= max_req_input_len + center : remove some tokens in center loc to make input token len <= max_req_input_len""", + ) + + args = parser.parse_args() + + # 非splitfuse 模式,不支持 prompt cache 特性 + if not args.splitfuse_mode: + assert len(args.prompt_cache_strs) == 0 + + model_config = AutoConfig.from_pretrained(args.model_dir) + context_length = get_context_length(model_config) + + if args.max_req_input_len is None: + args.max_req_input_len = context_length - 1 + if args.max_req_total_len is None: + args.max_req_total_len = context_length + + assert args.max_req_input_len < args.max_req_total_len + assert args.max_req_total_len <= args.max_total_token_num + + if not args.splitfuse_mode: + # 普通模式下 + if args.batch_max_tokens is None: + batch_max_tokens = int(1 / 6 * args.max_total_token_num) + batch_max_tokens = max(batch_max_tokens, args.max_req_total_len) + args.batch_max_tokens = batch_max_tokens + else: + assert ( + args.batch_max_tokens >= args.max_req_total_len + ), "batch_max_tokens must >= max_req_total_len" + else: + # splitfuse 模式下 + # assert args.batch_max_tokens is not None, "need to set by yourself" + if args.batch_max_tokens is None: + batch_max_tokens = int(1 / 6 * args.max_total_token_num) + batch_max_tokens = max(batch_max_tokens, args.splitfuse_block_size) + args.batch_max_tokens = batch_max_tokens + + can_use_ports = alloc_can_use_network_port(num=6 + args.tp) + + assert can_use_ports is not None, "Can not alloc enough free ports." + ( + router_port, + detokenization_port, + httpserver_port, + visual_port, + cache_port, + nccl_port, + ) = can_use_ports[0:6] + args.nccl_port = nccl_port + model_rpc_ports = can_use_ports[6:] + + global httpserver_manager + httpserver_manager = HttpServerManager( + args, + router_port=router_port, + cache_port=cache_port, + visual_port=visual_port, + httpserver_port=httpserver_port, + enable_multimodal=False, + ) + + start_submodule_processes( + start_funcs=[start_router_process, start_detokenization_process], + start_args=[ + (args, router_port, detokenization_port, model_rpc_ports), + (args, detokenization_port, httpserver_port), + ], + ) + worker = LightLLMWorker( + args.controller_address, + args.worker_address, + worker_id, + args.model_dir, + args.model_names, + args.limit_worker_concurrency, + args.no_register, + args.conv_template, + httpserver_manager.tokenizer, + context_length, + ) + + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/fastchat/serve/mlx_worker.py b/fastchat/serve/mlx_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..a7e85f848eadbf2e492e0298f5a3609e50f12e59 --- /dev/null +++ b/fastchat/serve/mlx_worker.py @@ -0,0 +1,288 @@ +""" +A model worker using Apple MLX + +https://github.com/ml-explore/mlx-examples/tree/main/llms + +Code based on vllm_worker https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/vllm_worker.py + +You must install MLX python: + +pip install mlx-lm +""" + +import argparse +import asyncio +import atexit +import json +from typing import List +import uuid + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.concurrency import run_in_threadpool +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn + +from fastchat.serve.base_model_worker import BaseModelWorker +from fastchat.serve.model_worker import ( + logger, + worker_id, +) +from fastchat.utils import get_context_length, is_partial_stop + +import mlx.core as mx 
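+
+# The streaming loop in MLXWorker.generate_stream below is built on the mlx_lm
+# primitives imported just below. A rough standalone sketch of that pattern
+# (comments only; the model name is a placeholder and generate_step's exact
+# signature may differ between mlx-lm versions):
+#
+#   model, tokenizer = load("microsoft/phi-2")
+#   prompt = mx.array(tokenizer.encode("Hello"))
+#   it = generate_step(prompt, model, 0.7)        # yields (token, _) pairs
+#   for _ in range(256):                          # cap the number of new tokens
+#       token, _ = next(it)
+#       if token == tokenizer.eos_token_id:
+#           break
+#       print(tokenizer.decode([token.item()]), end="", flush=True)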
+from mlx_lm import load, generate +from mlx_lm.utils import generate_step + +app = FastAPI() + + +class MLXWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + llm_engine: "MLX", + conv_template: str, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template, + ) + + logger.info( + f"Loading the model {self.model_names} on worker {worker_id}, worker type: MLX worker..." + ) + + self.model_name = model_path + self.mlx_model, self.mlx_tokenizer = load(model_path) + + self.tokenizer = self.mlx_tokenizer + # self.context_len = get_context_length( + # llm_engine.engine.model_config.hf_config) + self.context_len = 2048 # hard code for now -- not sure how to get in MLX + + if not no_register: + self.init_heart_beat() + + async def generate_stream(self, params): + self.call_ct += 1 + + context = params.pop("prompt") + request_id = params.pop("request_id") + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = params.get("top_k", -1.0) + presence_penalty = float(params.get("presence_penalty", 0.0)) + frequency_penalty = float(params.get("frequency_penalty", 0.0)) + max_new_tokens = params.get("max_new_tokens", 256) + stop_str = params.get("stop", None) + stop_token_ids = params.get("stop_token_ids", None) or [] + if self.tokenizer.eos_token_id is not None: + stop_token_ids.append(self.tokenizer.eos_token_id) + echo = params.get("echo", True) + use_beam_search = params.get("use_beam_search", False) + best_of = params.get("best_of", None) + + # Handle stop_str + stop = set() + if isinstance(stop_str, str) and stop_str != "": + stop.add(stop_str) + elif isinstance(stop_str, list) and stop_str != []: + stop.update(stop_str) + + for tid in stop_token_ids: + if tid is not None: + s = self.tokenizer.decode(tid) + if s != "": + stop.add(s) + + print("Stop patterns: ", stop) + + top_p = max(top_p, 1e-5) + if temperature <= 1e-5: + top_p = 1.0 + + tokens = [] + skip = 0 + + context_mlx = mx.array(self.tokenizer.encode(context)) + + finish_reason = "length" + + iterator = await run_in_threadpool( + generate_step, context_mlx, self.mlx_model, temperature + ) + + for i in range(max_new_tokens): + (token, _) = await run_in_threadpool(next, iterator) + if token == self.mlx_tokenizer.eos_token_id: + finish_reason = "stop" + break + tokens.append(token.item()) + tokens_decoded = self.mlx_tokenizer.decode(tokens) + last_token_decoded = self.mlx_tokenizer.decode([token.item()]) + skip = len(tokens_decoded) + + partial_stop = any(is_partial_stop(tokens_decoded, i) for i in stop) + + if partial_stop: + finish_reason = "stop" + break + + ret = { + "text": tokens_decoded, + "error_code": 0, + "usage": { + "prompt_tokens": len(context), + "completion_tokens": len(tokens), + "total_tokens": len(context) + len(tokens), + }, + "cumulative_logprob": [], + "finish_reason": None, # hard code for now + } + # print(ret) + yield (json.dumps(ret) + "\0").encode() + ret = { + "text": self.mlx_tokenizer.decode(tokens), + "error_code": 0, + "usage": {}, + "cumulative_logprob": [], + "finish_reason": finish_reason, + } + yield (json.dumps(obj={**ret, **{"finish_reason": None}}) + "\0").encode() + yield (json.dumps(ret) + "\0").encode() + + async def generate(self, params): + async for x in self.generate_stream(params): + pass + return 
json.loads(x[:-1].decode()) + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(request_id): + async def abort_request() -> None: + print("trying to abort but not implemented") + + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + background_tasks.add_task(abort_request) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + request_id = uuid.uuid4() + params["request_id"] = str(request_id) + generator = worker.generate_stream(params) + background_tasks = create_background_tasks(request_id) + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + request_id = uuid.uuid4() + params["request_id"] = str(request_id) + output = await worker.generate(params) + release_worker_semaphore() + # await engine.abort(request_id) + print("Trying to abort but not implemented") + return JSONResponse(output) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} + + +worker = None + + +def cleanup_at_exit(): + global worker + print("Cleaning up...") + del worker + + +atexit.register(cleanup_at_exit) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--model-path", type=str, default="microsoft/phi-2") + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument( + "--trust_remote_code", + action="store_false", + default=True, + help="Trust remote code (e.g., from HuggingFace) when" + "downloading the model and tokenizer.", + ) + + args, unknown = parser.parse_known_args() + + if args.model_path: + args.model = args.model_path + + worker = MLXWorker( + args.controller_address, + args.worker_address, + worker_id, + args.model_path, + args.model_names, + 1024, + False, + "MLX", + args.conv_template, + ) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/fastchat/serve/model_worker.py b/fastchat/serve/model_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..683a78556dd6062e395d23effa3faa77a422cf58 --- /dev/null +++ b/fastchat/serve/model_worker.py @@ -0,0 +1,425 @@ +""" +A model worker that executes the model. 
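+
+Example launch (illustrative; the model path and port below are placeholders
+taken from this repository's defaults, not requirements):
+
+    python3 -m fastchat.serve.model_worker --model-path lmsys/vicuna-7b-v1.5 --port 21002
+
+Each streamed chunk produced by generate_stream_gate below is a JSON object
+terminated by a zero byte.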
+""" +import argparse +import base64 +import gc +import json +import os +from typing import List, Optional +import uuid + +import torch +import torch.nn.functional as F +from transformers import set_seed +import uvicorn + +from fastchat.constants import ErrorCode, SERVER_ERROR_MSG +from fastchat.model.model_adapter import ( + load_model, + add_model_args, + get_generate_stream_function, +) +from fastchat.modules.awq import AWQConfig +from fastchat.modules.exllama import ExllamaConfig +from fastchat.modules.xfastertransformer import XftConfig +from fastchat.modules.gptq import GptqConfig +from fastchat.serve.base_model_worker import BaseModelWorker, app +from fastchat.utils import ( + build_logger, + get_context_length, + str_to_torch_dtype, +) + +worker_id = str(uuid.uuid4())[:8] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") + + +class ModelWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + device: str, + num_gpus: int, + max_gpu_memory: str, + revision: str = None, + dtype: Optional[torch.dtype] = None, + load_8bit: bool = False, + cpu_offloading: bool = False, + gptq_config: Optional[GptqConfig] = None, + awq_config: Optional[AWQConfig] = None, + exllama_config: Optional[ExllamaConfig] = None, + xft_config: Optional[XftConfig] = None, + stream_interval: int = 2, + conv_template: Optional[str] = None, + embed_in_truncate: bool = False, + seed: Optional[int] = None, + debug: bool = False, + **kwargs, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template=conv_template, + ) + + logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...") + self.model, self.tokenizer = load_model( + model_path, + revision=revision, + device=device, + num_gpus=num_gpus, + max_gpu_memory=max_gpu_memory, + dtype=dtype, + load_8bit=load_8bit, + cpu_offloading=cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + exllama_config=exllama_config, + xft_config=xft_config, + debug=debug, + ) + self.device = device + if self.tokenizer.pad_token == None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.context_len = get_context_length(self.model.config) + self.generate_stream_func = get_generate_stream_function(self.model, model_path) + self.stream_interval = stream_interval + self.embed_in_truncate = embed_in_truncate + self.seed = seed + + if not no_register: + self.init_heart_beat() + + def generate_stream_gate(self, params): + if self.device == "npu": + import torch_npu + + torch_npu.npu.set_device("npu:0") + self.call_ct += 1 + + try: + if self.seed is not None: + set_seed(self.seed) + for output in self.generate_stream_func( + self.model, + self.tokenizer, + params, + self.device, + self.context_len, + self.stream_interval, + ): + ret = { + "text": output["text"], + "error_code": 0, + } + if "usage" in output: + ret["usage"] = output["usage"] + if "finish_reason" in output: + ret["finish_reason"] = output["finish_reason"] + if "logprobs" in output: + ret["logprobs"] = output["logprobs"] + yield json.dumps(ret).encode() + b"\0" + except torch.cuda.OutOfMemoryError as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.CUDA_OUT_OF_MEMORY, + } + yield json.dumps(ret).encode() + b"\0" + except (ValueError, RuntimeError) as e: + ret = { + "text": 
f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + yield json.dumps(ret).encode() + b"\0" + + def generate_gate(self, params): + for x in self.generate_stream_gate(params): + pass + return json.loads(x[:-1].decode()) + + def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict): + if model_type_dict.get("is_bert"): + model_output = self.model(input_ids) + if model_type_dict.get("is_robert"): + data = model_output.last_hidden_state + else: + data = model_output[0] + elif model_type_dict.get("is_t5"): + model_output = self.model(input_ids, decoder_input_ids=input_ids) + data = model_output.encoder_last_hidden_state + else: + model_output = self.model(input_ids, output_hidden_states=True) + if model_type_dict.get("is_chatglm"): + data = model_output.hidden_states[-1].transpose(0, 1) + else: + data = model_output.hidden_states[-1] + + if hasattr(self.model, "use_cls_pooling") and self.model.use_cls_pooling: + sum_embeddings = data[:, 0] + else: + mask = attention_mask.unsqueeze(-1).expand(data.size()).float() + masked_embeddings = data * mask + sum_embeddings = torch.sum(masked_embeddings, dim=1) + token_num = torch.sum(attention_mask).item() + + return sum_embeddings, token_num + + def __encode_base64(self, embeddings: torch.Tensor) -> List[str]: + embeddings = embeddings.cpu() + return [ + base64.b64encode(e.numpy().tobytes()).decode("utf-8") for e in embeddings + ] + + @torch.inference_mode() + def get_embeddings(self, params): + self.call_ct += 1 + + try: + tokenizer = self.tokenizer + ret = {"embedding": [], "token_num": 0} + + model_type_dict = { + "is_llama": "llama" in str(type(self.model)), + "is_t5": "t5" in str(type(self.model)), + "is_chatglm": "chatglm" in str(type(self.model)), + "is_bert": "bert" in str(type(self.model)), + "is_robert": "robert" in str(type(self.model)), + } + + if self.embed_in_truncate: + encoding = tokenizer.batch_encode_plus( + params["input"], + padding=True, + truncation="longest_first", + return_tensors="pt", + max_length=self.context_len, + ) + else: + encoding = tokenizer.batch_encode_plus( + params["input"], padding=True, return_tensors="pt" + ) + input_ids = encoding["input_ids"].to(self.device) + attention_mask = input_ids != tokenizer.pad_token_id + + base64_encode = params.get("encoding_format", None) + + if self.embed_in_truncate: + embedding, token_num = self.__process_embed_chunk( + input_ids, attention_mask, **model_type_dict + ) + if ( + not hasattr(self.model, "use_cls_pooling") + or not self.model.use_cls_pooling + ): + embedding = embedding / token_num + normalized_embeddings = F.normalize(embedding, p=2, dim=1) + ret["token_num"] = token_num + else: + all_embeddings = [] + all_token_num = 0 + for i in range(0, input_ids.size(1), self.context_len): + chunk_input_ids = input_ids[:, i : i + self.context_len] + chunk_attention_mask = attention_mask[:, i : i + self.context_len] + + # add cls token and mask to get cls embedding + if ( + hasattr(self.model, "use_cls_pooling") + and self.model.use_cls_pooling + ): + cls_tokens = ( + torch.zeros( + (chunk_input_ids.size(0), 1), + dtype=chunk_input_ids.dtype, + device=chunk_input_ids.device, + ) + + tokenizer.cls_token_id + ) + chunk_input_ids = torch.cat( + [cls_tokens, chunk_input_ids], dim=-1 + ) + mask = torch.ones( + (chunk_attention_mask.size(0), 1), + dtype=chunk_attention_mask.dtype, + device=chunk_attention_mask.device, + ) + chunk_attention_mask = torch.cat( + [mask, chunk_attention_mask], dim=-1 + ) + + chunk_embeddings, token_num = 
self.__process_embed_chunk( + chunk_input_ids, chunk_attention_mask, **model_type_dict + ) + if ( + hasattr(self.model, "use_cls_pooling") + and self.model.use_cls_pooling + ): + all_embeddings.append(chunk_embeddings * token_num) + else: + all_embeddings.append(chunk_embeddings) + all_token_num += token_num + + all_embeddings_tensor = torch.stack(all_embeddings) + embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num + normalized_embeddings = F.normalize(embedding, p=2, dim=1) + + ret["token_num"] = all_token_num + + if base64_encode == "base64": + out_embeddings = self.__encode_base64(normalized_embeddings) + else: + out_embeddings = normalized_embeddings.tolist() + ret["embedding"] = out_embeddings + + gc.collect() + torch.cuda.empty_cache() + if self.device == "xpu": + torch.xpu.empty_cache() + if self.device == "npu": + torch.npu.empty_cache() + except torch.cuda.OutOfMemoryError as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.CUDA_OUT_OF_MEMORY, + } + except (ValueError, RuntimeError) as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + return ret + + +def create_model_worker(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + add_model_args(parser) + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument("--embed-in-truncate", action="store_true") + parser.add_argument( + "--limit-worker-concurrency", + type=int, + default=5, + help="Limit the model concurrency to prevent OOM.", + ) + parser.add_argument("--stream-interval", type=int, default=2) + parser.add_argument("--no-register", action="store_true") + parser.add_argument( + "--seed", + type=int, + default=None, + help="Overwrite the random seed for each generation.", + ) + parser.add_argument( + "--debug", type=bool, default=False, help="Print debugging messages" + ) + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + if args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + + gptq_config = GptqConfig( + ckpt=args.gptq_ckpt or args.model_path, + wbits=args.gptq_wbits, + groupsize=args.gptq_groupsize, + act_order=args.gptq_act_order, + ) + awq_config = AWQConfig( + ckpt=args.awq_ckpt or args.model_path, + wbits=args.awq_wbits, + groupsize=args.awq_groupsize, + ) + if args.enable_exllama: + exllama_config = ExllamaConfig( + max_seq_len=args.exllama_max_seq_len, + gpu_split=args.exllama_gpu_split, + cache_8bit=args.exllama_cache_8bit, + ) + else: + exllama_config = None + if args.enable_xft: + xft_config = XftConfig( + max_seq_len=args.xft_max_seq_len, + data_type=args.xft_dtype, + ) + if args.device != "cpu": + print("xFasterTransformer now is only support CPUs. 
Reset device to CPU") + args.device = "cpu" + else: + xft_config = None + + worker = ModelWorker( + args.controller_address, + args.worker_address, + worker_id, + args.model_path, + args.model_names, + args.limit_worker_concurrency, + revision=args.revision, + no_register=args.no_register, + device=args.device, + num_gpus=args.num_gpus, + max_gpu_memory=args.max_gpu_memory, + dtype=str_to_torch_dtype(args.dtype), + load_8bit=args.load_8bit, + cpu_offloading=args.cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + exllama_config=exllama_config, + xft_config=xft_config, + stream_interval=args.stream_interval, + conv_template=args.conv_template, + embed_in_truncate=args.embed_in_truncate, + seed=args.seed, + debug=args.debug, + ) + return args, worker + + +if __name__ == "__main__": + args, worker = create_model_worker() + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/fastchat/serve/monitor/basic_stats.py b/fastchat/serve/monitor/basic_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..3c1a8793d00ae2cd40da085e09448428f9ef5cff --- /dev/null +++ b/fastchat/serve/monitor/basic_stats.py @@ -0,0 +1,220 @@ +import argparse +import code +import datetime +import json +import os +from pytz import timezone +import time + +import pandas as pd # pandas>=2.0.3 +import plotly.express as px +import plotly.graph_objects as go +from tqdm import tqdm + + +NUM_SERVERS = 14 +LOG_ROOT_DIR = "~/fastchat_logs" + + +def get_log_files(max_num_files=None): + log_root = os.path.expanduser(LOG_ROOT_DIR) + filenames = [] + for i in range(NUM_SERVERS): + for filename in os.listdir(f"{log_root}/server{i}"): + if filename.endswith("-conv.json"): + filepath = f"{log_root}/server{i}/{filename}" + name_tstamp_tuple = (filepath, os.path.getmtime(filepath)) + filenames.append(name_tstamp_tuple) + # sort by tstamp + filenames = sorted(filenames, key=lambda x: x[1]) + filenames = [x[0] for x in filenames] + + max_num_files = max_num_files or len(filenames) + filenames = filenames[-max_num_files:] + return filenames + + +def load_log_files(filename): + data = [] + for retry in range(5): + try: + lines = open(filename).readlines() + break + except FileNotFoundError: + time.sleep(2) + + for l in lines: + row = json.loads(l) + data.append( + dict( + type=row["type"], + tstamp=row["tstamp"], + model=row.get("model", ""), + models=row.get("models", ["", ""]), + ) + ) + return data + + +def load_log_files_parallel(log_files, num_threads=16): + data_all = [] + from multiprocessing import Pool + + with Pool(num_threads) as p: + ret_all = list(tqdm(p.imap(load_log_files, log_files), total=len(log_files))) + for ret in ret_all: + data_all.extend(ret) + return data_all + + +def get_anony_vote_df(df): + anony_vote_df = df[ + df["type"].isin(["leftvote", "rightvote", "tievote", "bothbad_vote"]) + ] + anony_vote_df = anony_vote_df[anony_vote_df["models"].apply(lambda x: x[0] == "")] + return anony_vote_df + + +def merge_counts(series, on, names): + ret = pd.merge(series[0], series[1], on=on) + for i in range(2, len(series)): + ret = pd.merge(ret, series[i], on=on) + ret = ret.reset_index() + old_names = list(ret.columns)[-len(series) :] + rename = {old_name: new_name for old_name, new_name in zip(old_names, names)} + ret = ret.rename(columns=rename) + return ret + + +def 
report_basic_stats(log_files): + df_all = load_log_files_parallel(log_files) + df_all = pd.DataFrame(df_all) + now_t = df_all["tstamp"].max() + df_1_hour = df_all[df_all["tstamp"] > (now_t - 3600)] + df_1_day = df_all[df_all["tstamp"] > (now_t - 3600 * 24)] + anony_vote_df_all = get_anony_vote_df(df_all) + + # Chat trends + chat_dates = [ + datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime( + "%Y-%m-%d" + ) + for x in df_all[df_all["type"] == "chat"]["tstamp"] + ] + chat_dates_counts = pd.value_counts(chat_dates) + vote_dates = [ + datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime( + "%Y-%m-%d" + ) + for x in anony_vote_df_all["tstamp"] + ] + vote_dates_counts = pd.value_counts(vote_dates) + chat_dates_bar = go.Figure( + data=[ + go.Bar( + name="Anony. Vote", + x=vote_dates_counts.index, + y=vote_dates_counts, + text=[f"{val:.0f}" for val in vote_dates_counts], + textposition="auto", + ), + go.Bar( + name="Chat", + x=chat_dates_counts.index, + y=chat_dates_counts, + text=[f"{val:.0f}" for val in chat_dates_counts], + textposition="auto", + ), + ] + ) + chat_dates_bar.update_layout( + barmode="stack", + xaxis_title="Dates", + yaxis_title="Count", + height=300, + width=1200, + ) + + # Model call counts + model_hist_all = df_all[df_all["type"] == "chat"]["model"].value_counts() + model_hist_1_day = df_1_day[df_1_day["type"] == "chat"]["model"].value_counts() + model_hist_1_hour = df_1_hour[df_1_hour["type"] == "chat"]["model"].value_counts() + model_hist = merge_counts( + [model_hist_all, model_hist_1_day, model_hist_1_hour], + on="model", + names=["All", "Last Day", "Last Hour"], + ) + model_hist_md = model_hist.to_markdown(index=False, tablefmt="github") + + # Action counts + action_hist_all = df_all["type"].value_counts() + action_hist_1_day = df_1_day["type"].value_counts() + action_hist_1_hour = df_1_hour["type"].value_counts() + action_hist = merge_counts( + [action_hist_all, action_hist_1_day, action_hist_1_hour], + on="type", + names=["All", "Last Day", "Last Hour"], + ) + action_hist_md = action_hist.to_markdown(index=False, tablefmt="github") + + # Anony vote counts + anony_vote_hist_all = anony_vote_df_all["type"].value_counts() + anony_vote_df_1_day = get_anony_vote_df(df_1_day) + anony_vote_hist_1_day = anony_vote_df_1_day["type"].value_counts() + # anony_vote_df_1_hour = get_anony_vote_df(df_1_hour) + # anony_vote_hist_1_hour = anony_vote_df_1_hour["type"].value_counts() + anony_vote_hist = merge_counts( + [anony_vote_hist_all, anony_vote_hist_1_day], + on="type", + names=["All", "Last Day"], + ) + anony_vote_hist_md = anony_vote_hist.to_markdown(index=False, tablefmt="github") + + # Last 24 hours + chat_1_day = df_1_day[df_1_day["type"] == "chat"] + num_chats_last_24_hours = [] + base = df_1_day["tstamp"].min() + for i in range(24, 0, -1): + left = base + (i - 1) * 3600 + right = base + i * 3600 + num = ((chat_1_day["tstamp"] >= left) & (chat_1_day["tstamp"] < right)).sum() + num_chats_last_24_hours.append(num) + times = [ + datetime.datetime.fromtimestamp( + base + i * 3600, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + for i in range(24, 0, -1) + ] + last_24_hours_df = pd.DataFrame({"time": times, "value": num_chats_last_24_hours}) + last_24_hours_md = last_24_hours_df.to_markdown(index=False, tablefmt="github") + + # Last update datetime + last_updated_tstamp = now_t + last_updated_datetime = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + + # 
code.interact(local=locals()) + + return { + "chat_dates_bar": chat_dates_bar, + "model_hist_md": model_hist_md, + "action_hist_md": action_hist_md, + "anony_vote_hist_md": anony_vote_hist_md, + "num_chats_last_24_hours": last_24_hours_md, + "last_updated_datetime": last_updated_datetime, + } + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--max-num-files", type=int) + args = parser.parse_args() + + log_files = get_log_files(args.max_num_files) + basic_stats = report_basic_stats(log_files) + + print(basic_stats["action_hist_md"] + "\n") + print(basic_stats["model_hist_md"] + "\n") + print(basic_stats["anony_vote_hist_md"] + "\n") + print(basic_stats["num_chats_last_24_hours"] + "\n") diff --git a/fastchat/serve/monitor/classify/category.py b/fastchat/serve/monitor/classify/category.py new file mode 100644 index 0000000000000000000000000000000000000000..223144a32f35635ced7b78c1a077da5fc34f51a6 --- /dev/null +++ b/fastchat/serve/monitor/classify/category.py @@ -0,0 +1,136 @@ +# Tag structure +# - category_tag +# - criteria_v0.1 +# - specificity +# - ... +# - math_v0.1 +# - math +# - if_v0.1 +# - if +# - score +import ast +import re + + +class Category: + def __init__(self): + pass + + @staticmethod + def create_category(name): + if name == "criteria_v0.1": + return CategoryHardPrompt() + elif name == "if_v0.1": + return CategoryIF() + elif name == "math_v0.1": + return CategoryMath() + + raise Exception(f"Category name is incorrect: {name}") + + def post_process(self): + pass + + +class CategoryHardPrompt(Category): + def __init__(self): + super().__init__() + self.name_tag = "criteria_v0.1" + self.pattern = re.compile(r"(\[\d(?:\,\s\d)*\])") + self.sys_prompt = "Your task is to evaluate how well the following input prompts can assess the capabilities of advanced AI assistants.\n\nFor the input prompt, please analyze it based on the following 7 criteria.\n1. Specificity: Does the prompt ask for a specific output, such as code, a mathematical solution, a logical simplification, a problem-solving strategy, or a hardware setup recommendation? This specificity allows the AI to demonstrate its ability to understand and generate precise responses.\n2. Domain Knowledge: Does the prompt cover a specific domain, such as programming, mathematics, logic, problem-solving, or hardware setup? Prompts spanning a range of topics test the AI's breadth of knowledge and its ability to apply that knowledge to different domains.\n3. Complexity: Does the prompt vary in complexity, from straightforward tasks to more complex, multi-step problems? This allows evaluators to assess the AI's capability to handle problems of varying difficulty.\n4. Problem-Solving Skills: Does the prompt directly involves the AI to demonstrate active problem-solving skills, such systemically coming up with a solution for a specific setup instead of regurgitating an existing fact? This tests the AI's ability to apply logical reasoning and provide practical solutions.\n5. Creativity: Does the prompt involve a level of creativity in approaching the problem? This criterion tests the AI's ability to provide tailored solutions that take into account the user's specific needs and limitations.\n6. Technical Accuracy: Does the prompt require technical accuracy in the response? This allows evaluators to assess the AI's precision and correctness in technical fields.\n7. 
Real-world Application: Does the prompt relate to real-world applications, such as setting up a functional system or writing code for a practical use case? This tests the AI's ability to provide practical and actionable information that could be implemented in real-life scenarios.\n\nYou must list the criteria numbers that the prompt satisfies in the format of a Python array. For example, \"[...]\". Do not explain your choice." + self.tags = { + 1: "specificity", + 2: "domain_knowledge", + 3: "complexity", + 4: "problem_solving", + 5: "creativity", + 6: "technical_accuracy", + 7: "real_world", + } + + def get_score(self, judgment): + matches = self.pattern.findall(judgment) + matches = [m for m in matches if m != ""] + if len(set(matches)) == 0: + return [] + elif len(set(matches)) == 1: + try: + return ast.literal_eval(matches[0]) + except SyntaxError: + print(matches[0]) + return [] + else: + return [] + + def pre_process(self, prompt): + conv = [{"role": "system", "content": self.sys_prompt}] + conv.append({"role": "user", "content": prompt}) + return conv + + def post_process(self, judgment): + criteria = self.get_score(judgment=judgment) + return {name: bool(i in criteria) for i, name in self.tags.items()} + + +class CategoryIF(Category): + def __init__(self): + super().__init__() + self.name_tag = "if_v0.1" + self.pattern = re.compile(r"([012345])<\/score>") + self.system_prompt = "You are an AI assistant tasked with determining whether a given user prompt can effectively assess another AI's ability to follow instructions. Your goal is to analyze the prompt and decide if it contains specific, clear instructions that would test an AI's capability to understand and execute directions accurately. Carefully examine the user prompt and consider the following aspects:\n1. Does it contain specific instructions or requirements?\n2. Are there multiple steps or elements the AI needs to address?\n3. Does it ask for a particular format or structure in the response?\n4. Is there a unique or challenging aspect that would test the AI's ability to follow directions precisely?\n\nConsider both the content and the structure of the instructions. A good prompt for assessing instruction-following capabilities should have clear, specific directions that can be objectively evaluated. Think about why this prompt does or does not effectively assess an AI's ability to follow instructions. Consider both the strengths and weaknesses of the prompt in this regard. Output your verdict as a score from 0 to 5:\n0 = Does not evaluate instruction-following ability.\n1 = Ineffective at evaluating instruction-following ability.\n2 = Somewhat effective at evaluating instruction-following ability.\n3 = Effective at evaluating simple instruction-following ability.\n4 = Effective at evaluating more complex instruction-following ability.\n5 = Effective at evaluating advanced instruction-following ability.\n\nPresent your score in the following format:\n[Your score from 0 to 5].\nDo NOT explain." 
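+        # Illustrative only (assumed judge reply, not part of the original patch):
+        # a judgment whose text matches self.pattern above with the digit 4 makes
+        # post_process below return {"if": True, "score": 4}.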
+ self.prompt_template = "{PROMPT}" + + def get_score(self, judgment): + matches = self.pattern.findall(judgment) + matches = [m for m in matches if m != ""] + if len(set(matches)) == 0: + return None + elif len(set(matches)) == 1: + return int(matches[0]) + else: + return None + + def pre_process(self, prompt): + args = {"PROMPT": prompt} + conv = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": self.prompt_template.format(**args)}, + ] + return conv + + def post_process(self, judgment): + score = self.get_score(judgment=judgment) + return { + "if": bool(score >= 4) if score else False, + "score": score, + } + + +class CategoryMath(Category): + def __init__(self): + super().__init__() + self.name_tag = "math_v0.1" + self.pattern = re.compile(r"(\w+)<\/decision>") + self.system_prompt = 'You are tasked with determining whether a given user prompt requires an AI assistant to solve a math problem and apply mathematical logic and reasoning.\n\nCarefully analyze the user prompt and consider whether it requires mathematical problem-solving skills to answer correctly. Think about the following aspects:\n\n1. Does it require the application of a specific mathematical concept or formula?\n2. Does the prompt involve numerical calculations or algebraic manipulation or logical reasoning?\n3. Is there a clear mathematical problem to be solved?\n4. Would answering this prompt demonstrate proficiency in a specific area in mathematics?\n\nOutput your verdict in the following format:"\n[yes/no]\n". Do NOT explain.' + self.prompt_template = "\n{PROMPT}\n" + + def get_score(self, judgment): + matches = self.pattern.findall(judgment.replace("\n", "").lower()) + matches = [m for m in matches if m != ""] + if len(set(matches)) == 0: + return None + elif len(set(matches)) == 1: + return matches[0] + else: + return None + + def pre_process(self, prompt): + args = {"PROMPT": prompt} + conv = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": self.prompt_template.format(**args)}, + ] + return conv + + def post_process(self, judgment): + score = self.get_score(judgment=judgment) + return {"math": bool(score == "yes") if score else False} diff --git a/fastchat/serve/monitor/classify/config.yaml b/fastchat/serve/monitor/classify/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..63da566c5f767ed8f9914b2fcefd43e115144fd8 --- /dev/null +++ b/fastchat/serve/monitor/classify/config.yaml @@ -0,0 +1,25 @@ +# Yaml config file for category classification + +input_file: null # json +cache_file: null # json +output_file: null # json line + +convert_to_json: True + +task_name: + - criteria_v0.1 + - if_v0.1 + - math_v0.1 + +model_name: null +name: llama-3-70b-instruct +endpoints: + - api_base: null + api_key: null +parallel: 50 +temperature: 0.0 +max_token: 512 + +max_retry: 2 +retry_sleep: 10 +error_output: $ERROR$ \ No newline at end of file diff --git a/fastchat/serve/monitor/classify/label.py b/fastchat/serve/monitor/classify/label.py new file mode 100644 index 0000000000000000000000000000000000000000..2d0471a1f1fb4713e19130ee3f9aa55be7a5b266 --- /dev/null +++ b/fastchat/serve/monitor/classify/label.py @@ -0,0 +1,309 @@ +import argparse +import json +import pandas as pd +import os +import time +import concurrent.futures +import tqdm +import yaml +import random +import threading +import orjson + +from category import Category + + +LOCK = threading.RLock() + +TASKS = None +CACHE_DICT = None +OUTPUT_DICT = None + +# API setting 
constants +API_MAX_RETRY = None +API_RETRY_SLEEP = None +API_ERROR_OUTPUT = None + + +# load config args from config yaml files +def make_config(config_file: str) -> dict: + config_kwargs = {} + with open(config_file, "r") as f: + config_kwargs = yaml.load(f, Loader=yaml.SafeLoader) + return config_kwargs + + +def get_endpoint(endpoint_list): + if endpoint_list is None: + return None + assert endpoint_list is not None + # randomly pick one + api_dict = random.choices(endpoint_list)[0] + return api_dict + + +def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=None): + import openai + + if api_dict: + client = openai.OpenAI( + base_url=api_dict["api_base"], + api_key=api_dict["api_key"], + ) + else: + client = openai.OpenAI() + + output = API_ERROR_OUTPUT + for _ in range(API_MAX_RETRY): + try: + # print(messages) + completion = client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + # extra_body={"guided_choice": GUIDED_CHOICES} if GUIDED_CHOICES else None, + ) + output = completion.choices[0].message.content + # print(output) + break + except openai.RateLimitError as e: + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + except openai.BadRequestError as e: + print(messages) + print(type(e), e) + break + except openai.APIConnectionError as e: + print(messages) + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + except openai.InternalServerError as e: + print(messages) + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + except Exception as e: + print(type(e), e) + break + + return output + + +def get_answer( + question: dict, + model_name: str, + max_tokens: int, + temperature: float, + answer_file: str, + api_dict: dict, + categories: list, + testing: bool, +): + if "category_tag" in question: + category_tag = question["category_tag"] + else: + category_tag = {} + + output_log = {} + + for category in categories: + conv = category.pre_process(question["prompt"]) + output = chat_completion_openai( + model=model_name, + messages=conv, + temperature=temperature, + max_tokens=max_tokens, + api_dict=api_dict, + ) + # Dump answers + category_tag[category.name_tag] = category.post_process(output) + + if testing: + output_log[category.name_tag] = output + + question["category_tag"] = category_tag + if testing: + question["output_log"] = output_log + + question.drop(["prompt", "uid", "required_tasks"], inplace=True) + + with LOCK: + with open(answer_file, "a") as fout: + fout.write(json.dumps(question.to_dict()) + "\n") + + +def category_merge(row): + id = row["uid"] + input_category = row["category_tag"] if "category_tag" in row else {} + cache_category = CACHE_DICT[id]["category_tag"] if id in CACHE_DICT else {} + output_category = OUTPUT_DICT[id]["category_tag"] if id in OUTPUT_DICT else {} + + # tries to fill in missing categories using cache first, then output + for name in TASKS: + if name not in input_category: + if name in cache_category: + input_category[name] = cache_category[name] + continue + if name in output_category: + input_category[name] = output_category[name] + + return input_category + + +def find_required_tasks(row): + id = row["uid"] + input_category = row["category_tag"] if "category_tag" in row else {} + cache_category = CACHE_DICT[id]["category_tag"] if id in CACHE_DICT else {} + output_category = OUTPUT_DICT[id]["category_tag"] if id in OUTPUT_DICT else {} + + return [ + name + for name in TASKS + if not ( + name in input_category or name in cache_category or name in output_category 
+ ) + ] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--testing", action="store_true") + args = parser.parse_args() + + enter = input( + "Make sure your config file is properly configured. Press enter to continue." + ) + if not enter == "": + exit() + + config = make_config(args.config) + + API_MAX_RETRY = config["max_retry"] + API_RETRY_SLEEP = config["retry_sleep"] + API_ERROR_OUTPUT = config["error_output"] + + categories = [Category.create_category(name) for name in config["task_name"]] + TASKS = config["task_name"] + print( + f"Following categories will be labeled:\n{[category.name_tag for category in categories]}" + ) + + print("loading input data (might take min)") + with open(config["input_file"], "rb") as f: + data = orjson.loads(f.read()) + input_data = pd.DataFrame(data) + + # much faster than pd.apply + input_data["uid"] = input_data.question_id.map(str) + input_data.tstamp.map(str) + assert len(input_data) == len(input_data.uid.unique()) + print(f"{len(input_data)}# of input data just loaded") + + if config["cache_file"]: + print("loading cache data") + with open(config["cache_file"], "rb") as f: + data = orjson.loads(f.read()) + cache_data = pd.DataFrame(data) + cache_data["uid"] = cache_data.question_id.map(str) + cache_data.tstamp.map(str) + assert len(cache_data) == len(cache_data.uid.unique()) + + print(f"{len(cache_data)}# of cache data just loaded") + + assert "category_tag" in cache_data.columns + cache_dict = cache_data[["uid", "category_tag"]].set_index("uid") + print("finalizing cache_dict (should take less than 30 sec)") + CACHE_DICT = cache_dict.to_dict("index") + else: + CACHE_DICT = {} + + if os.path.isfile(config["output_file"]): + print("loading existing output") + output_data = pd.read_json(config["output_file"], lines=True) + output_data["uid"] = output_data.question_id.map(str) + output_data.tstamp.map( + str + ) + assert len(output_data) == len(output_data.uid.unique()) + + print(f"{len(output_data)}# of existing output just loaded") + + assert "category_tag" in output_data.columns + output_dict = output_data[["uid", "category_tag"]].set_index("uid") + print("finalizing output_dict (should take less than 30 sec)") + OUTPUT_DICT = output_dict.to_dict("index") + else: + OUTPUT_DICT = {} + + print( + "finding tasks needed to run... 
(should take around 1 minute or less on large dataset)" + ) + input_data["required_tasks"] = input_data.apply(find_required_tasks, axis=1) + + not_labeled = input_data[input_data.required_tasks.map(lambda x: len(x) > 0)].copy() + + print(f"{len(not_labeled)} # of conversations needs to be labeled") + for name in TASKS: + print( + f"{name}: {len(not_labeled[not_labeled.required_tasks.map(lambda tasks: name in tasks)])}" + ) + + not_labeled["prompt"] = not_labeled.conversation_a.map( + lambda convo: "\n".join([convo[i]["content"] for i in range(0, len(convo), 2)]) + ) + not_labeled["prompt"] = not_labeled.prompt.map(lambda x: x[:12500]) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=config["parallel"] + ) as executor: + futures = [] + for index, row in tqdm.tqdm(not_labeled.iterrows()): + future = executor.submit( + get_answer, + row, + config["model_name"], + config["max_token"], + config["temperature"], + config["output_file"], + get_endpoint(config["endpoints"]), + [ + category + for category in categories + if category.name_tag in row["required_tasks"] + ], + args.testing, + ) + futures.append(future) + for future in tqdm.tqdm( + concurrent.futures.as_completed(futures), total=len(futures) + ): + future.result() + + if config["convert_to_json"]: + # merge two data frames, but only take the fields from the cache data to overwrite the input data + merge_columns = [category.name_tag for category in categories] + print(f"Columns to be merged:\n{merge_columns}") + + input_data["uid"] = input_data.question_id.map(str) + input_data.tstamp.map(str) + assert len(input_data) == len(input_data.uid.unique()) + + # fastest way to merge + assert os.path.isfile(config["output_file"]) + print("reading output file...") + temp = pd.read_json(config["output_file"], lines=True) + temp["uid"] = temp.question_id.map(str) + temp.tstamp.map(str) + assert len(temp) == len(temp.uid.unique()) + + assert "category_tag" in temp.columns + output_dict = temp[["uid", "category_tag"]].set_index("uid") + print("finalizing output_dict (should take less than 30 sec)") + OUTPUT_DICT = output_dict.to_dict("index") + + print("begin merging (should take around 1 minute or less on large dataset)") + input_data["category_tag"] = input_data.apply(category_merge, axis=1) + print("merge completed") + + final_data = input_data.drop( + columns=["prompt", "uid", "required_tasks"], errors="ignore" + ) + final_data.to_json( + config["output_file"][:-1], orient="records", indent=4, force_ascii=False + ) diff --git a/fastchat/serve/monitor/clean_battle_data.py b/fastchat/serve/monitor/clean_battle_data.py new file mode 100644 index 0000000000000000000000000000000000000000..270f981ccf62eb822d823ed299d864f6533fad17 --- /dev/null +++ b/fastchat/serve/monitor/clean_battle_data.py @@ -0,0 +1,423 @@ +""" +Clean chatbot arena battle log. 
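+This script reads the per-server vote logs, drops invalid, leaked-identity, errored, unfinished, and banned battles, and writes the cleaned battles to a clean_battle_*.json file sorted by timestamp.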
+ +Usage: +python3 clean_battle_data.py --mode conv_release +""" +import argparse +import datetime +import json +import os +from pytz import timezone +import time + +from tqdm import tqdm +from multiprocessing import Pool +import tiktoken +from collections import Counter +import shortuuid + +from fastchat.serve.monitor.basic_stats import get_log_files, NUM_SERVERS +from fastchat.utils import detect_language + + +VOTES = ["tievote", "leftvote", "rightvote", "bothbad_vote"] +IDENTITY_WORDS = [ + "vicuna", + "lmsys", + "koala", + "uc berkeley", + "open assistant", + "laion", + "chatglm", + "chatgpt", + "gpt-4", + "openai", + "anthropic", + "claude", + "bard", + "palm", + "lamda", + "google", + "gemini", + "llama", + "qianwan", + "qwen", + "alibaba", + "mistral", + "zhipu", + "KEG lab", + "01.AI", + "AI2", + "Tülu", + "Tulu", + "deepseek", + "hermes", + "cohere", + "DBRX", + "databricks", +] + +ERROR_WORDS = [ + "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.", + "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES.", + "API REQUEST ERROR. Please increase the number of max tokens.", + "**API REQUEST ERROR** Reason: The response was blocked.", + "**API REQUEST ERROR**", +] + +UNFINISHED_WORDS = [ + "▌", + '', +] + +for i in range(len(IDENTITY_WORDS)): + IDENTITY_WORDS[i] = IDENTITY_WORDS[i].lower() + +for i in range(len(ERROR_WORDS)): + ERROR_WORDS[i] = ERROR_WORDS[i].lower() + + +def remove_html(raw): + if isinstance(raw, str) and raw.startswith("
<h3>"): + return raw[raw.find(": ") + 2 : -len("</h3>
\n")] + return raw + + +def to_openai_format(messages): + roles = ["user", "assistant"] + ret = [] + for i, x in enumerate(messages): + ret.append({"role": roles[i % 2], "content": x[1]}) + return ret + + +def replace_model_name(old_name, tstamp): + replace_dict = { + "bard": "palm-2", + "claude-v1": "claude-1", + "claude-instant-v1": "claude-instant-1", + "oasst-sft-1-pythia-12b": "oasst-pythia-12b", + "claude-2": "claude-2.0", + "StripedHyena-Nous-7B": "stripedhyena-nous-7b", + "gpt-4-turbo": "gpt-4-1106-preview", + "gpt-4-0125-assistants-api": "gpt-4-turbo-browsing", + } + if old_name in ["gpt-4", "gpt-3.5-turbo"]: + if tstamp > 1687849200: + return old_name + "-0613" + else: + return old_name + "-0314" + if old_name in replace_dict: + return replace_dict[old_name] + return old_name + + +def read_file(filename): + data = [] + for retry in range(5): + try: + # lines = open(filename).readlines() + for l in open(filename): + row = json.loads(l) + if row["type"] in VOTES: + data.append(row) + break + except FileNotFoundError: + time.sleep(2) + return data + + +def read_file_parallel(log_files, num_threads=16): + data_all = [] + with Pool(num_threads) as p: + ret_all = list(tqdm(p.imap(read_file, log_files), total=len(log_files))) + for ret in ret_all: + data_all.extend(ret) + return data_all + + +def process_data( + data, + exclude_model_names, + sanitize_ip, + ban_ip_list, +): + encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") + convert_type = { + "leftvote": "model_a", + "rightvote": "model_b", + "tievote": "tie", + "bothbad_vote": "tie (bothbad)", + } + + all_ips = dict() + + count_dict = { + "anony": 0, + "invalid": 0, + "leaked_identity": 0, + "banned": 0, + "error": 0, + "unfinished": 0, + "none_msg": 0, + "exclude_model": 0, + } + count_leak = {} + + battles = [] + for row in data: + flag_anony = False + flag_leaked_identity = False + flag_error = False + flag_unfinished = False + flag_none_msg = False + + if row["models"][0] is None or row["models"][1] is None: + continue + + # Resolve model names + models_public = [remove_html(row["models"][0]), remove_html(row["models"][1])] + if "model_name" in row["states"][0]: + models_hidden = [ + row["states"][0]["model_name"], + row["states"][1]["model_name"], + ] + if models_hidden[0] is None: + models_hidden = models_public + else: + models_hidden = models_public + + if (models_public[0] == "" and models_public[1] != "") or ( + models_public[1] == "" and models_public[0] != "" + ): + count_dict["invalid"] += 1 + continue + + if models_public[0] == "" or models_public[0] == "Model A": + flag_anony = True + models = models_hidden + else: + flag_anony = False + models = models_public + if ( + models_hidden[0] not in models_public[0] + or models_hidden[1] not in models_public[1] + ): + count_dict["invalid"] += 1 + continue + + # Detect langauge + state = row["states"][0] + if state["offset"] >= len(state["messages"]): + count_dict["invalid"] += 1 + continue + lang_code = detect_language(state["messages"][state["offset"]][1]) + + # Drop conversations if the model names are leaked + messages = "" + for i in range(2): + state = row["states"][i] + for _, (role, msg) in enumerate(state["messages"][state["offset"] :]): + if msg: + messages += msg.lower() + else: + flag_none_msg = True + + for word in IDENTITY_WORDS: + if word in messages: + if word not in count_leak: + count_leak[word] = 0 + count_leak[word] += 1 + flag_leaked_identity = True + break + + for word in ERROR_WORDS: + if word in messages: + flag_error = True + break + + for word 
in UNFINISHED_WORDS: + if word in messages: + flag_unfinished = True + break + + if flag_none_msg: + count_dict["none_msg"] += 1 + continue + if flag_leaked_identity: + count_dict["leaked_identity"] += 1 + continue + if flag_error: + count_dict["error"] += 1 + continue + if flag_unfinished: + count_dict["unfinished"] += 1 + continue + + # Replace bard with palm + models = [replace_model_name(m, row["tstamp"]) for m in models] + # Exclude certain models + if exclude_model_names and any(x in exclude_model_names for x in models): + count_dict["exclude_model"] += 1 + continue + + question_id = row["states"][0]["conv_id"] + conversation_a = to_openai_format( + row["states"][0]["messages"][row["states"][0]["offset"] :] + ) + conversation_b = to_openai_format( + row["states"][1]["messages"][row["states"][1]["offset"] :] + ) + + ip = row["ip"] + if ip not in all_ips: + all_ips[ip] = {"ip": ip, "count": 0, "sanitized_id": shortuuid.uuid()} + all_ips[ip]["count"] += 1 + if sanitize_ip: + user_id = f"{all_ips[ip]['sanitized_id']}" + else: + user_id = f"{all_ips[ip]['ip']}" + + if ban_ip_list is not None and ip in ban_ip_list: + count_dict["banned"] += 1 + continue + + if flag_anony: + count_dict["anony"] += 1 + + for conv in conversation_a: + conv["num_tokens"] = len( + encoding.encode(conv["content"], allowed_special="all") + ) + for conv in conversation_b: + conv["num_tokens"] = len( + encoding.encode(conv["content"], allowed_special="all") + ) + + # Save the results + battles.append( + dict( + question_id=question_id, + model_a=models[0], + model_b=models[1], + winner=convert_type[row["type"]], + judge=f"arena_user_{user_id}", + conversation_a=conversation_a, + conversation_b=conversation_b, + turn=len(conversation_a) // 2, + anony=flag_anony, + language=lang_code, + tstamp=row["tstamp"], + ) + ) + return battles, count_dict, count_leak, all_ips + + +def clean_battle_data( + log_files, + exclude_model_names, + ban_ip_list=None, + sanitize_ip=False, + anony_only=False, + num_threads=16, +): + data = read_file_parallel(log_files, num_threads=16) + + battles = [] + count_dict = {} + count_leak = {} + all_ips = {} + with Pool(num_threads) as p: + # split data into chunks + chunk_size = len(data) // min(100, len(data)) + data_chunks = [ + data[i : i + chunk_size] for i in range(0, len(data), chunk_size) + ] + + args_list = [ + (data_chunk, exclude_model_names, sanitize_ip, ban_ip_list) + for data_chunk in data_chunks + ] + ret_all = list(tqdm(p.starmap(process_data, args_list), total=len(data_chunks))) + + for ret in ret_all: + sub_battles, sub_count_dict, sub_count_leak, sub_all_ips = ret + battles.extend(sub_battles) + count_dict = dict(Counter(count_dict) + Counter(sub_count_dict)) + count_leak = dict(Counter(count_leak) + Counter(sub_count_leak)) + for ip in sub_all_ips: + if ip not in all_ips: + all_ips[ip] = sub_all_ips[ip] + else: + all_ips[ip]["count"] += sub_all_ips[ip]["count"] + battles.sort(key=lambda x: x["tstamp"]) + last_updated_tstamp = battles[-1]["tstamp"] + + last_updated_datetime = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + + print(f"#votes: {len(data)}") + print(count_dict) + print(f"#battles: {len(battles)}, #anony: {count_dict['anony']}") + print(f"last-updated: {last_updated_datetime}") + print(f"leaked_identity: {count_leak}") + + if ban_ip_list is not None: + for ban_ip in ban_ip_list: + if ban_ip in all_ips: + del all_ips[ban_ip] + print("Top 30 IPs:") + print(sorted(all_ips.values(), key=lambda x: 
x["count"], reverse=True)[:30]) + return battles + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--max-num-files", type=int) + parser.add_argument( + "--mode", type=str, choices=["simple", "conv_release"], default="simple" + ) + parser.add_argument("--exclude-model-names", type=str, nargs="+") + parser.add_argument("--ban-ip-file", type=str) + parser.add_argument("--sanitize-ip", action="store_true", default=False) + args = parser.parse_args() + + log_files = get_log_files(args.max_num_files) + ban_ip_list = json.load(open(args.ban_ip_file)) if args.ban_ip_file else None + + battles = clean_battle_data( + log_files, args.exclude_model_names or [], ban_ip_list, args.sanitize_ip + ) + last_updated_tstamp = battles[-1]["tstamp"] + cutoff_date = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y%m%d") + + if args.mode == "simple": + for x in battles: + for key in [ + "conversation_a", + "conversation_b", + "question_id", + ]: + del x[key] + print("Samples:") + for i in range(4): + print(battles[i]) + output = f"clean_battle_{cutoff_date}.json" + elif args.mode == "conv_release": + new_battles = [] + for x in battles: + if not x["anony"]: + continue + for key in []: + del x[key] + new_battles.append(x) + battles = new_battles + output = f"clean_battle_conv_{cutoff_date}.json" + + with open(output, "w", encoding="utf-8", errors="replace") as fout: + json.dump(battles, fout, indent=2, ensure_ascii=False) + print(f"Write cleaned data to {output}") diff --git a/fastchat/serve/monitor/clean_chat_data.py b/fastchat/serve/monitor/clean_chat_data.py new file mode 100644 index 0000000000000000000000000000000000000000..2bda0e2c3a5242f65fab00e05b1199afdf864cc9 --- /dev/null +++ b/fastchat/serve/monitor/clean_chat_data.py @@ -0,0 +1,171 @@ +""" +Clean chatbot arena chat log. + +Usage: +python3 clean_chat_data.py +""" +import argparse +import datetime +import json +import os +from pytz import timezone +import time + +from tqdm import tqdm + +from fastchat.serve.monitor.basic_stats import NUM_SERVERS +from fastchat.serve.monitor.clean_battle_data import ( + to_openai_format, + replace_model_name, +) +from fastchat.utils import detect_language + + +NETWORK_ERROR_MSG = ( + "NETWORK ERROR DUE TO HIGH TRAFFIC. 
PLEASE REGENERATE OR REFRESH THIS PAGE.".lower() +) + + +def get_log_files(max_num_files=None): + dates = [] + for month in range(4, 12): + for day in range(1, 33): + dates.append(f"2023-{month:02d}-{day:02d}") + + filenames = [] + for d in dates: + for i in range(NUM_SERVERS): + name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") + if os.path.exists(name): + filenames.append(name) + max_num_files = max_num_files or len(filenames) + # filenames = list(reversed(filenames)) + filenames = filenames[-max_num_files:] + return filenames + + +def clean_chat_data(log_files, action_type): + raw_data = [] + for filename in tqdm(log_files, desc="read files"): + for retry in range(5): + try: + lines = open(filename).readlines() + break + except FileNotFoundError: + time.sleep(2) + + for l in lines: + row = json.loads(l) + if row["type"] == action_type: + raw_data.append(row) + + all_models = set() + all_ips = dict() + chats = [] + ct_invalid_conv_id = 0 + ct_invalid = 0 + ct_network_error = 0 + for row in raw_data: + try: + if action_type in ["chat", "upvote", "downvote"]: + state = row["state"] + model = row["model"] + elif action_type == "leftvote": + state = row["states"][0] + model = row["states"][0]["model_name"] + elif action_type == "rightvote": + state = row["states"][1] + model = row["states"][1]["model_name"] + conversation_id = state["conv_id"] + except KeyError: + ct_invalid_conv_id += 1 + continue + + if conversation_id is None: + ct_invalid_conv_id += 1 + continue + + conversation = to_openai_format(state["messages"][state["offset"] :]) + if not isinstance(model, str): + ct_invalid += 1 + continue + model = replace_model_name(model, row["tstamp"]) + + try: + lang_code = detect_language(state["messages"][state["offset"]][1]) + except IndexError: + ct_invalid += 1 + continue + + if not all(isinstance(x["content"], str) for x in conversation): + ct_invalid += 1 + continue + + messages = "".join([x["content"] for x in conversation]).lower() + if NETWORK_ERROR_MSG in messages: + ct_network_error += 1 + continue + + ip = row["ip"] + if ip not in all_ips: + all_ips[ip] = len(all_ips) + user_id = all_ips[ip] + + chats.append( + dict( + conversation_id=conversation_id, + model=model, + conversation=conversation, + turn=len(conversation) // 2, + language=lang_code, + user_id=user_id, + tstamp=row["tstamp"], + ) + ) + + all_models.update([model]) + + chats.sort(key=lambda x: x["tstamp"]) + last_updated_tstamp = chats[-1]["tstamp"] + last_updated_datetime = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + + # Deduplication + dedup_chats = [] + visited_conv_ids = set() + for i in reversed(range(len(chats))): + if chats[i]["conversation_id"] in visited_conv_ids: + continue + visited_conv_ids.add(chats[i]["conversation_id"]) + dedup_chats.append(chats[i]) + + print( + f"#raw: {len(raw_data)}, #chat: {len(chats)}, #dedup_chat: {len(dedup_chats)}" + ) + print( + f"#invalid_conv_id: {ct_invalid_conv_id}, #network_error: {ct_network_error}, #invalid: {ct_invalid}" + ) + print(f"#models: {len(all_models)}, {all_models}") + print(f"last-updated: {last_updated_datetime}") + + return list(reversed(dedup_chats)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--action-type", type=str, default="chat") + parser.add_argument("--max-num-files", type=int) + args = parser.parse_args() + + log_files = get_log_files(args.max_num_files) + chats = clean_chat_data(log_files, 
args.action_type) + last_updated_tstamp = chats[-1]["tstamp"] + cutoff_date = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y%m%d") + + output = f"clean_{args.action_type}_conv_{cutoff_date}.json" + with open(output, "w") as fout: + json.dump(chats, fout, indent=2, ensure_ascii=False) + print(f"Write cleaned data to {output}") diff --git a/fastchat/serve/monitor/code_tagger.py b/fastchat/serve/monitor/code_tagger.py new file mode 100644 index 0000000000000000000000000000000000000000..12eeaed4b25e0f7b4b6b31d562de905d4ac6a8d2 --- /dev/null +++ b/fastchat/serve/monitor/code_tagger.py @@ -0,0 +1,180 @@ +import re +import json +import argparse +import multiprocessing as mp + +import nltk +from tqdm import tqdm +from nltk.tokenize import word_tokenize + + +def is_code_conversation(text: str) -> tuple[bool, list[str]]: + """Check if the text is a code conversation""" + + if "```plaintext" in text: + lines = text.split("\n") + line1_idx = [idx for idx, line in enumerate(lines) if "```plaintext" in line][0] + line2_idx = [ + line1_idx + 1 + idx + for idx, line in enumerate(lines) + if "```" in line[line1_idx + 1 :] + ] + if line2_idx: + line2_idx = line2_idx[0] + text = "\n".join(lines[:line1_idx]) + "\n".join(lines[line2_idx + 1 :]) + else: + text = "\n".join(lines[:line1_idx]) + return is_code_conversation(text) + + if "```markdown" in text: + otext = text + lines = text.split("\n") + line1_idx = [idx for idx, line in enumerate(lines) if "```markdown" in line][0] + line2_idx = [ + line1_idx + 1 + idx + for idx, line in enumerate(lines) + if "```" in line[line1_idx + 1 :] + ] + if line2_idx: + line2_idx = line2_idx[0] + text = "\n".join(lines[:line1_idx]) + "\n".join(lines[line2_idx + 1 :]) + else: + text = "\n".join(lines[:line1_idx]) + return is_code_conversation(text) + + if "ascii art" in text.lower(): + return False, [] + + # 1. Check for code formatting + if re.search(r"```", text): + return True, ["backticks"] + + # Tokenize the text + tokens = word_tokenize(text) + tokens = [token.lower() for token in tokens] + + # 2. Check for programming concepts + concepts = ["git", "github", "pull request", "dataframe", "nginx", "pip"] + if any(concept in tokens for concept in concepts): + matched_concepts = list(set(tokens).intersection(set(concepts))) + return True, matched_concepts + + # 3. Check for programming language name + languages = [ + "python", + "c++", + "cpp", + "java", + "javascript", + "typescript", + "html", + "css", + "sql", + "bash", + "powershell", + "matlab", + "golang", + "linux", + "ubuntu", + ] + if any(language in tokens for language in languages): + matched_languages = list(set(tokens).intersection(set(languages))) + return True, matched_languages + + # 4. Programming concept substrings + strings = [ + "import pandas", + "import numpy", + "import torch", + "jax", + "tensorflow", + "pytorch", + "keras", + "scikit-learn", + "sklearn", + " apt-get ", + ] + found_array = [string in text for string in strings] + if any(found_array): + matched_strings = [ + string for string, found in zip(strings, found_array) if found + ] + return True, matched_strings + + # 5. 
Programming concept regexes + regexes = [ + r"from \w+ import \w+", + r"conda install \w+", + r"pip install -r \w+", + r"conda install -c \w+ \w+", + r"#include <\w+>", + r"import \w+ as \w+", + r"#include \"\w+\.h\"", + ] + found_array = [re.search(regex, text) for regex in regexes] + if any(found_array): + matched_regexes = [regex for regex, found in zip(regexes, found_array) if found] + return True, matched_regexes + + return False, [] + + +def check_code_conv(conv) -> tuple[bool, list[str]]: + """Check if the conversation is a code conversation""" + for _, msg in enumerate(conv): + content = msg["content"] + if not isinstance(content, str): + continue + is_code_conv_res = is_code_conversation(content) + if is_code_conv_res[0]: + return is_code_conv_res + return False, [] + + +def check_conv_row(conv_row): + check_a, code_a = check_code_conv(conv_row["conversation_a"]) + check_b, code_b = check_code_conv(conv_row["conversation_b"]) + + return check_a or check_b, code_a + code_b + + +def process_battle_file(battle_file_path: str, n_cpus: int): + with open(battle_file_path, "r") as f: + data = json.load(f) + + with mp.Pool(n_cpus) as pool: + tagged_data = list(tqdm(pool.imap(check_conv_row, data), total=len(data))) + + output_data = [row for row, (is_code, _) in zip(data, tagged_data) if is_code] + + return output_data + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--clean-battle-file", type=str) + parser.add_argument("--output-clean-battle-file", type=str, default=None) + parser.add_argument("--n-cpus", type=int, default=-1) + + args = parser.parse_args() + + if args.output_clean_battle_file is None: + args.output_clean_battle_file = args.clean_battle_file + + if args.n_cpus == -1: + args.n_cpus = mp.cpu_count() + + print( + f"Processing {args.clean_battle_file} and saving to {args.output_clean_battle_file} with {args.n_cpus} cpus" + ) + + output_data = process_battle_file(args.clean_battle_file, args.n_cpus) + + with open(args.output_clean_battle_file, "w") as f: + json.dump(output_data, f, indent=4) + + print(f"Total code conversations: {len(output_data)}") + print("Done!") + + with open(args.output_clean_battle_file, "r") as f: + data = json.load(f) diff --git a/fastchat/serve/monitor/criteria_labeling.py b/fastchat/serve/monitor/criteria_labeling.py new file mode 100644 index 0000000000000000000000000000000000000000..b08b030d45bc7faff2a8c85e5923f701a94d4cc4 --- /dev/null +++ b/fastchat/serve/monitor/criteria_labeling.py @@ -0,0 +1,214 @@ +import argparse +import json +import pandas as pd +import os +import re +import ast +import time +import concurrent.futures +import tqdm +import random +import threading + +LOCK = threading.RLock() + +## Configs +SYSTEM_PROMPT = "Your task is to evaluate how well the following input prompts can assess the capabilities of advanced AI assistants.\n\nFor the input prompt, please analyze it based on the following 7 criteria.\n1. Specificity: Does the prompt ask for a specific output, such as code, a mathematical solution, a logical simplification, a problem-solving strategy, or a hardware setup recommendation? This specificity allows the AI to demonstrate its ability to understand and generate precise responses.\n2. Domain Knowledge: Does the prompt cover a specific domain, such as programming, mathematics, logic, problem-solving, or hardware setup? Prompts spanning a range of topics test the AI's breadth of knowledge and its ability to apply that knowledge to different domains.\n3. 
Complexity: Does the prompt vary in complexity, from straightforward tasks to more complex, multi-step problems? This allows evaluators to assess the AI's capability to handle problems of varying difficulty.\n4. Problem-Solving Skills: Does the prompt directly involves the AI to demonstrate active problem-solving skills, such systemically coming up with a solution for a specific setup instead of regurgitating an existing fact? This tests the AI's ability to apply logical reasoning and provide practical solutions.\n5. Creativity: Does the prompt involve a level of creativity in approaching the problem? This criterion tests the AI's ability to provide tailored solutions that take into account the user's specific needs and limitations.\n6. Technical Accuracy: Does the prompt require technical accuracy in the response? This allows evaluators to assess the AI's precision and correctness in technical fields.\n7. Real-world Application: Does the prompt relate to real-world applications, such as setting up a functional system or writing code for a practical use case? This tests the AI's ability to provide practical and actionable information that could be implemented in real-life scenarios.\n\nYou must list the criteria numbers that the prompt satisfies in the format of a Python array. For example, \"[...]\". Do not explain your choice." + +ENDPOINT_INFO = { + "model_name": "META-LLAMA/LLAMA-3-70B-CHAT-HF", + "name": "llama-3-70b-instruct", + "endpoints": [{"api_base": "-", "api_key": "-"}], + "parallel": 8, + "temperature": 0.0, + "max_token": 512, +} # Modify this + +TAGS = { + 1: "specificity", + 2: "domain_knowledge", + 3: "complexity", + 4: "problem_solving", + 5: "creativity", + 6: "technical_accuracy", + 7: "real_world", +} + +# API setting constants +API_MAX_RETRY = 3 +API_RETRY_SLEEP = 10 +API_ERROR_OUTPUT = "$ERROR$" + + +def get_endpoint(endpoint_list): + if endpoint_list is None: + return None + assert endpoint_list is not None + # randomly pick one + api_dict = random.choices(endpoint_list)[0] + return api_dict + + +pattern = re.compile(r"(\[\d(?:\,\s\d)*\])") + + +def get_score(judgment): + matches = pattern.findall(judgment) + matches = [m for m in matches if m != ""] + if len(set(matches)) == 0: + return [] + elif len(set(matches)) == 1: + try: + return ast.literal_eval(matches[0]) + except SyntaxError: + print(matches[0]) + return [] + else: + return [] + + +def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=None): + import openai + + if api_dict: + client = openai.OpenAI( + base_url=api_dict["api_base"], + api_key=api_dict["api_key"], + ) + else: + client = openai.OpenAI() + + output = API_ERROR_OUTPUT + for _ in range(API_MAX_RETRY): + try: + # print(messages) + completion = client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + # extra_body={"guided_choice": GUIDED_CHOICES} if GUIDED_CHOICES else None, + ) + output = completion.choices[0].message.content + break + except openai.RateLimitError as e: + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + except openai.BadRequestError as e: + print(messages) + print(type(e), e) + break + except openai.APIConnectionError as e: + print(messages) + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + except openai.InternalServerError as e: + print(messages) + print(type(e), e) + time.sleep(1) + except KeyError: + print(type(e), e) + break + + return output + + +def get_answer( + question: dict, + max_tokens: int, + temperature: float, + 
answer_file: str, + api_dict: dict, +): + conv = [] + conv.append({"role": "system", "content": SYSTEM_PROMPT}) + + conv.append({"role": "user", "content": question["prompt"]}) + output = chat_completion_openai( + model=ENDPOINT_INFO["model_name"], + messages=conv, + temperature=temperature, + max_tokens=max_tokens, + api_dict=api_dict, + ) + + criteria = get_score(output) + + # Dump answers + question["criteria_tag"] = {name: bool(i in criteria) for i, name in TAGS.items()} + question.drop("prompt") + + with LOCK: + with open(answer_file, "a") as fout: + fout.write(json.dumps(question.to_dict()) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input-file", type=str, required=True) + parser.add_argument("--cache-file", type=str, default=None) + parser.add_argument("--output-file", type=str, required=True) + parser.add_argument("--convert-to-json", action="store_true") + args = parser.parse_args() + + print("loading input data (might take min)") + input_data = pd.read_json(args.input_file) + print(f"{len(input_data)}# of input data just loaded") + if args.cache_file: + print("loading cache data") + cache_data = pd.read_json(args.cache_file) + print(f"{len(cache_data)}# of cache data just loaded") + + assert "criteria_tag" in cache_data.columns and len( + cache_data["criteria_tag"].dropna() + ) == len(cache_data) + + not_labeled = input_data[ + ~input_data["question_id"].isin(cache_data["question_id"]) + ].copy() + else: + not_labeled = input_data.copy() + + if os.path.isfile(args.output_file): + print("loading existing output") + output_data = pd.read_json(args.output_file, lines=True) + print(f"{len(output_data)}# of existing output just loaded") + + assert "criteria_tag" in output_data.columns and len( + output_data["criteria_tag"].dropna() + ) == len(output_data) + + not_labeled = not_labeled[ + ~not_labeled["question_id"].isin(output_data["question_id"]) + ] + + print(f"{len(not_labeled)} needs to be labeled") + + not_labeled["prompt"] = not_labeled.conversation_a.map( + lambda convo: "\n".join([convo[i]["content"] for i in range(0, len(convo), 2)]) + ) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=ENDPOINT_INFO["parallel"] + ) as executor: + futures = [] + for index, row in tqdm.tqdm(not_labeled.iterrows()): + future = executor.submit( + get_answer, + row, + ENDPOINT_INFO["max_token"], + ENDPOINT_INFO["temperature"], + args.output_file, + get_endpoint(ENDPOINT_INFO["endpoints"]), + ) + futures.append(future) + for future in tqdm.tqdm( + concurrent.futures.as_completed(futures), total=len(futures) + ): + future.result() + + if args.convert_to_json: + temp = pd.read_json(args.output_file, lines=True) + temp.to_json( + args.output_file[:-1], orient="records", indent=4, force_ascii=False + ) diff --git a/fastchat/serve/monitor/dataset_release_scripts/arena_33k/count_unique_users.py b/fastchat/serve/monitor/dataset_release_scripts/arena_33k/count_unique_users.py new file mode 100644 index 0000000000000000000000000000000000000000..8e94cf2756203f207e82cc7f31ff544ecdcc80f0 --- /dev/null +++ b/fastchat/serve/monitor/dataset_release_scripts/arena_33k/count_unique_users.py @@ -0,0 +1,25 @@ +"""Count the unique users in a battle log file.""" + +import argparse +import json + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str) + args = parser.parse_args() + + lines = json.load(open(args.input)) + ct_anony_votes = 0 + all_users = set() + all_models = set() + for l in lines: 
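+ # Only anonymous battles are counted; collect the distinct judges (users) and the models on both sides.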
+ if not l["anony"]: + continue + all_users.add(l["judge"]) + all_models.add(l["model_a"]) + all_models.add(l["model_b"]) + ct_anony_votes += 1 + + print(f"#anony_vote: {ct_anony_votes}, #user: {len(all_users)}") + print(f"#model: {len(all_models)}") diff --git a/fastchat/serve/monitor/dataset_release_scripts/arena_33k/filter_bad_conv.py b/fastchat/serve/monitor/dataset_release_scripts/arena_33k/filter_bad_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..6d12d7c652bc02bb7b5c9f65bce0e1644f739c1b --- /dev/null +++ b/fastchat/serve/monitor/dataset_release_scripts/arena_33k/filter_bad_conv.py @@ -0,0 +1,155 @@ +""" +Filter conversations for release. + +Usage: python3 filter_bad_conv.py --in clean_battle_conv_20230630_tagged_v1_pii.json +""" +import argparse +from collections import defaultdict +from enum import Enum, auto +import json +import os +import random + +from tqdm import tqdm + +BLOCKED_WORDS_FILENAME = "blocked_words.json" +blocked_words = [] +frequency = defaultdict(lambda: 0) + + +class TypeCode(Enum): + CORRECT = auto() + ANONYMIZED = auto() + REDACTED = auto() + BAD_FORMAT = auto() + BLOCKED_WORD = auto() + BLOCKED_MODEL = auto() + TOO_SHORT = auto() + TOO_FREQUENT = auto() + + +def detect_type(conv): + for key in ["conversation_a", "conversation_b"]: + messages = [row["content"] for row in conv[key]] + for msg in messages: + if not isinstance(msg, str): + return TypeCode.BAD_FORMAT + + user_prompts = [ + row["content"].lower().strip() for row in conv[key] if row["role"] == "user" + ] + if len(messages) <= 2 and all(len(x) < 16 for x in user_prompts): + return TypeCode.TOO_SHORT + + if all(x in frequent_prompts for x in user_prompts): + return TypeCode.TOO_FREQUENT + + for msg in messages: + msg = msg.lower() + if "" in msg: + return TypeCode.ANONYMIZED + if "" in msg: + return TypeCode.REDACTED + + for w in blocked_words: + if w in msg: + return TypeCode.BLOCKED_WORD + + for key in ["model_a", "model_b"]: + if conv[key] in ["vicuna-33b", "mpt-30b-chat"]: + return TypeCode.BLOCKED_MODEL + + return TypeCode.CORRECT + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--sample", type=int) + args = parser.parse_args() + + # Read conversations + convs = json.load(open(args.in_file)) + print(f"#conv: {len(convs)}") + + # Read blocked words + if os.path.exists(BLOCKED_WORDS_FILENAME): + blocked_words = json.load(open(BLOCKED_WORDS_FILENAME)) + + # Count frequency + for conv in convs: + for key in ["conversation_a", "conversation_b"]: + messages = [row["content"] for row in conv[key] if row["role"] == "user"] + for msg in messages: + if not isinstance(msg, str): + continue + msg = msg.lower().strip() + frequency[msg] += 1 + + keys = list(frequency.keys()) + keys.sort(key=lambda x: -frequency[x]) + frequent_prompts = keys[:10] + frequent_prompts = set(frequent_prompts) + frequent_prompts.add("") + + # Start filter + ct_bad_format = 0 + ct_anonymized = 0 + ct_redacted = 0 + ct_error = 0 + ct_lang_filter = 0 + ct_flagged = 0 + ct_blocked_word = 0 + ct_blocked_model = 0 + ct_too_short = 0 + ct_too_frequent = 0 + + new_convs = [] + for conv in tqdm(convs): + type_code = detect_type(conv) + + if type_code == TypeCode.BAD_FORMAT: + ct_bad_format += 1 + continue + + if type_code == TypeCode.ANONYMIZED: + ct_anonymized += 1 + continue + elif type_code == TypeCode.REDACTED: + ct_redacted += 1 + continue + elif type_code == TypeCode.BLOCKED_WORD: + ct_blocked_word += 1 + 
continue + elif type_code == TypeCode.BLOCKED_MODEL: + ct_blocked_model += 1 + continue + elif type_code == TypeCode.TOO_SHORT: + ct_too_short += 1 + continue + elif type_code == TypeCode.TOO_FREQUENT: + ct_too_frequent += 1 + continue + + if conv["openai_moderation"]["flagged"]: + ct_flagged += 1 + continue + + if type_code in [TypeCode.CORRECT]: + new_convs.append(conv) + + if args.sample: + # random.seed(0) + # random.shuffle(new_convs) + new_convs = new_convs[: args.sample] + + print(f"ct_anonymized: {ct_anonymized}, ct_redacted: {ct_redacted}") + print(f"ct_bad_format: {ct_bad_format}, ct_flagged: {ct_flagged}") + print(f"ct_blocked_word: {ct_blocked_word}, ct_blocked_model: {ct_blocked_model}") + print(f"ct_too_short: {ct_too_short}, ct_too_frequent: {ct_anonymized}") + print(f"new_conv: {len(new_convs)}") + + out_file = args.in_file.replace(".json", ".out.json") + print(f"Output to {out_file}") + with open(out_file, "w") as fout: + json.dump(new_convs, fout, indent=2, ensure_ascii=False) diff --git a/fastchat/serve/monitor/dataset_release_scripts/arena_33k/merge_field.py b/fastchat/serve/monitor/dataset_release_scripts/arena_33k/merge_field.py new file mode 100644 index 0000000000000000000000000000000000000000..5a88209bfcb58cb2131ce94d6eba03c899e74a0a --- /dev/null +++ b/fastchat/serve/monitor/dataset_release_scripts/arena_33k/merge_field.py @@ -0,0 +1,25 @@ +"""Count the unique users in a battle log file.""" + +import argparse +import json + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str) + parser.add_argument("--tag-file", type=str) + args = parser.parse_args() + + # build index + objs = json.load(open(args.tag_file)) + new_field_dict = {} + for obj in objs: + new_field_dict[obj["question_id"]] = obj["toxic_chat"] + + objs = json.load(open(args.input)) + for obj in objs: + obj["toxic_chat_tag"] = new_field_dict[obj["question_id"]] + + output = args.input.replace(".json", "_added.json") + with open(output, "w") as fout: + json.dump(objs, fout, indent=2, ensure_ascii=False) diff --git a/fastchat/serve/monitor/dataset_release_scripts/arena_33k/sample.py b/fastchat/serve/monitor/dataset_release_scripts/arena_33k/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..0cd78b71e95a3034bf3440aee3557a38426d0244 --- /dev/null +++ b/fastchat/serve/monitor/dataset_release_scripts/arena_33k/sample.py @@ -0,0 +1,32 @@ +""" +Count the unique users in a battle log file. 
+ +Usage: +python3 -input in.json --number 1000 +""" + +import argparse +import json +import random + +K = 1000 + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str) + parser.add_argument("--number", type=int, nargs="+") + args = parser.parse_args() + + convs = json.load(open(args.input)) + random.seed(0) + random.shuffle(convs) + + for number in args.number: + new_convs = convs[:number] + + output = args.input.replace(".json", f"_{number//K}k.json") + with open(output, "w") as fout: + json.dump(new_convs, fout, indent=2, ensure_ascii=False) + + print(f"#in: {len(convs)}, #out: {len(new_convs)}") + print(f"Write to file: {output}") diff --git a/fastchat/serve/monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py b/fastchat/serve/monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e37aadcea65df7ca605369b88c068aa57c8f35f2 --- /dev/null +++ b/fastchat/serve/monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py @@ -0,0 +1,9 @@ +""" +Upload to huggingface. +""" +import json +from datasets import Dataset, DatasetDict, load_dataset + +objs = json.load(open("clean_battle_conv_20230630_tagged_v3_pii_33k_added.json")) +data = Dataset.from_list(objs) +data.push_to_hub("lmsys/chatbot_arena_conversations", private=True) diff --git a/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py new file mode 100644 index 0000000000000000000000000000000000000000..a7084207309907dcb8fa37eccf55fd2a6b62ca48 --- /dev/null +++ b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py @@ -0,0 +1,13 @@ +import requests + +headers = {"authorization": "Bearer hf_XXX"} + +url = "https://huggingface.co/api/datasets/lmsys/lmsys-chat-1m/user-access-request/pending" +a = requests.get(url, headers=headers) + +for u in a.json(): + user = u["user"]["user"] + url = "https://huggingface.co/api/datasets/lmsys/lmsys-chat-1m/user-access-request/grant" + ret = requests.post(url, headers=headers, json={"user": user}) + print(user, ret.status_code) + assert ret.status_code == 200 diff --git a/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..97abaaa0df053c93c3adb655f1b5c41af0aab00d --- /dev/null +++ b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py @@ -0,0 +1,119 @@ +""" +From colab: +https://colab.research.google.com/drive/1oMdw_Lqgmd6DletSOLHsyD-Rc96cRShs?usp=sharing +""" +import argparse +import datetime +import json +import os +from pytz import timezone +import time + +import kaleido +import numpy as np +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go +from tqdm import tqdm + +import plotly.io as pio + +pio.kaleido.scope.mathjax = None + +parser = argparse.ArgumentParser() +parser.add_argument("--in-file", type=str, required=True) +parser.add_argument("--scale", type=int, required=True) +args = parser.parse_args() + +filename = args.in_file +scale = args.scale +convs = json.load(open(filename)) +df = pd.DataFrame(convs) +df + +print(f"#ips: {df['user_id'].nunique() * scale}") +print(f"#models: {df['model'].nunique()}") +print(f"#language: {df['language'].nunique()}") +print(f"#turns: {df['turn'].mean()}") + +model_counts = 
df["model"].value_counts() * scale +# print("model counts", model_counts) +fig = px.bar(x=model_counts.index, y=model_counts) +fig.update_layout( + xaxis_title=None, + yaxis_title="Count", + height=200, + width=950, + margin=dict(l=0, r=0, t=0, b=0), +) +fig.show() +fig.write_image("model_count.pdf") + + +model_counts = df["language"].value_counts().head(25) * scale +fig = px.bar(x=model_counts.index, y=model_counts) +fig.update_layout( + xaxis_title=None, + yaxis_title="Count", + height=200, + width=950, + margin=dict(l=0, r=0, t=0, b=0), +) +fig.show() +fig.write_image("language_count.pdf") + +chat_dates = [ + datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime("%Y-%m-%d") + for x in df["tstamp"] +] + + +def to_remove(x): + for d in ["08-09", "08-08", "08-07", "08-06", "08-05", "08-04"]: + if d in x: + return True + return False + + +chat_dates = [x for x in chat_dates if not to_remove(x)] + +chat_dates_counts = pd.value_counts(chat_dates) * scale +print(f"mean #chat per day: {np.mean(chat_dates_counts):.2f}") + +fig = px.bar(x=chat_dates_counts.index, y=chat_dates_counts) +fig.update_layout( + xaxis_title="Dates", + yaxis_title="Count", + height=200, + width=950, + margin=dict(l=0, r=0, t=0, b=0), +) +fig.show() +fig.write_image("daily_conversation_count.pdf") + +import transformers + +tokenizer = transformers.AutoTokenizer.from_pretrained( + "lmsys/vicuna-7b-v1.5", use_fast=False +) + +prompts = [] +responses = [] +for conv in df["conversation"]: + for row in conv: + if row["role"] == "user": + prompts.append(row["content"]) + else: + responses.append(row["content"]) + +print(f"#prompts: {len(prompts)}") +print(f"#responses: {len(responses)}") + + +prompt_lens = [len(tokenizer(x).input_ids) for x in tqdm(prompts)] +print() +print(f"mean prompt len: {np.mean(prompt_lens):.2f}") + +response_lens = [len(tokenizer(x).input_ids) if x else 0 for x in tqdm(responses)] +print() +print(f"mean response len: {np.mean(response_lens):.2f}") diff --git a/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..3ccde1ca57546acf5d1131cae14a499f1228a02c --- /dev/null +++ b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py @@ -0,0 +1,148 @@ +""" +Filter conversations for release. 
+ +Dependency: +pip install opencc-python-reimplementedpip install opencc-python-reimplemented + +Usage: +python3 filter_bad_conv_lmsys_chat_1m.py --in clean_battle_conv_20230630_tagged_v1_pii.json +""" +import argparse +from concurrent.futures import ProcessPoolExecutor +from collections import defaultdict +from enum import Enum, auto +import json +import os +import random + +from tqdm import tqdm +import opencc + +BLOCKED_WORDS_FILENAME = "blocked_words.json" +blocked_words = [] +frequency = defaultdict(lambda: 0) + +cc_converter = opencc.OpenCC("t2s") + + +class TypeCode(Enum): + CORRECT = auto() + ANONYMIZED = auto() + REDACTED = auto() + BAD_FORMAT = auto() + BLOCKED_WORD = auto() + BLOCKED_MODEL = auto() + TOO_SHORT = auto() + TOO_FREQUENT = auto() + + +def detect_type(conv): + for key in ["conversation_a", "conversation_b", "conversation"]: + if key not in conv: + continue + + messages = [row["content"] for row in conv[key]] + for msg in messages: + if not isinstance(msg, str): + return TypeCode.BAD_FORMAT + + if len(messages) == 0: + return TypeCode.BAD_FORMAT + + user_prompts = [ + row["content"].lower().strip() for row in conv[key] if row["role"] == "user" + ] + + for msg in messages: + msg = cc_converter.convert(msg.lower()) + if "" in msg: + return TypeCode.ANONYMIZED + if "" in msg: + return TypeCode.REDACTED + + for w in blocked_words: + if w in msg: + return TypeCode.BLOCKED_WORD + + return TypeCode.CORRECT + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--sample", type=int) + args = parser.parse_args() + + # Read conversations + convs = json.load(open(args.in_file)) + print(f"#conv: {len(convs)}") + + # Read blocked words + if os.path.exists(BLOCKED_WORDS_FILENAME): + blocked_words = json.load(open(BLOCKED_WORDS_FILENAME)) + blocked_words = [cc_converter.convert(w) for w in blocked_words] + + # Start filter + ct_bad_format = 0 + ct_anonymized = 0 + ct_redacted = 0 + ct_error = 0 + ct_lang_filter = 0 + ct_flagged = 0 + ct_blocked_word = 0 + ct_blocked_model = 0 + ct_too_short = 0 + ct_too_frequent = 0 + + type_codes = [] + with ProcessPoolExecutor() as executor: + for result in tqdm(executor.map(detect_type, convs), total=len(convs)): + type_codes.append(result) + + new_convs = [] + for conv, type_code in zip(convs, type_codes): + if type_code == TypeCode.BAD_FORMAT: + ct_bad_format += 1 + continue + + if type_code == TypeCode.ANONYMIZED: + ct_anonymized += 1 + continue + elif type_code == TypeCode.REDACTED: + ct_redacted += 1 + continue + elif type_code == TypeCode.BLOCKED_WORD: + ct_blocked_word += 1 + continue + elif type_code == TypeCode.BLOCKED_MODEL: + ct_blocked_model += 1 + continue + elif type_code == TypeCode.TOO_SHORT: + ct_too_short += 1 + continue + elif type_code == TypeCode.TOO_FREQUENT: + ct_too_frequent += 1 + continue + + if "openai_moderation" in conv and conv["openai_moderation"]["flagged"]: + ct_flagged += 1 + continue + + if type_code in [TypeCode.CORRECT]: + new_convs.append(conv) + + if args.sample: + random.seed(42) + random.shuffle(new_convs) + new_convs = new_convs[: args.sample] + + print(f"ct_anonymized: {ct_anonymized}, ct_redacted: {ct_redacted}") + print(f"ct_bad_format: {ct_bad_format}, ct_flagged: {ct_flagged}") + print(f"ct_blocked_word: {ct_blocked_word}, ct_blocked_model: {ct_blocked_model}") + print(f"ct_too_short: {ct_too_short}, ct_too_frequent: {ct_too_frequent}") + print(f"new_conv: {len(new_convs)}") + + out_file = 
args.in_file.replace(".json", ".s1.json") + print(f"Output to {out_file}") + with open(out_file, "w") as fout: + json.dump(new_convs, fout, indent=2, ensure_ascii=False) diff --git a/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..e368e92a1dcf260ecb5b175b77e85c6971809a3c --- /dev/null +++ b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py @@ -0,0 +1,27 @@ +import argparse +import json + +from tqdm import tqdm +import numpy as np + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + args = parser.parse_args() + + # Read conversations + convs = json.load(open(args.in_file)) + print(f"#conv: {len(convs)}") + + # Delete some fileds + for c in convs: + del c["tstamp"] + del c["user_id"] + + # Write + print(f"#out conv: {len(convs)}") + out_file = args.in_file.replace(".json", ".s2.json") + print(f"Output to {out_file}") + with open(out_file, "w") as fout: + json.dump(convs, fout, indent=2, ensure_ascii=False) diff --git a/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md new file mode 100644 index 0000000000000000000000000000000000000000..4c439731f6aee43bd29e1a65576c5ae04ff59cfa --- /dev/null +++ b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md @@ -0,0 +1,23 @@ +``` +export BASE=clean_conv_20230809_100k_pii +export SCALE=10 + +# filter words +python3 filter_bad_conv.py --in $BASE.json + +# Clean up some fileds (e.g., timestamps) +python3 final_post_processing.py --in $BASE.s1.json + +# upload to hf +python3 upload_hf_dataset.py --in $BASE.s1.s2.json + +# Make another version with openai moderation tag +python3 merge_oai_tag.py --in $BASE.s1.s2.json + +# Make visualizations +python3 compute_stats.py --in $BASE.s1.json --scale $SCALE + +# Copy figures +scp "atlas:/data/lmzheng/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/*.pdf" . 
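+# Each step appends a suffix to the previous output (.s1 from filter_bad_conv.py, .s1.s2 from final_post_processing.py), so individual steps can be re-run without redoing the whole pipeline.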
+``` + diff --git a/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py new file mode 100644 index 0000000000000000000000000000000000000000..18bef5f1962384d80f174aa22a7b6dcc867fe7c0 --- /dev/null +++ b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py @@ -0,0 +1,45 @@ +import argparse +import json +import time + +from tqdm import tqdm + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--sample", type=int) + args = parser.parse_args() + + tag_file = "clean_conv_20230809_1.5M_oai_filter_v2.json" + # tag_file = "clean_conv_20230809_1.5M_oai_filter_v2_100k.json" + in_file = args.in_file + tic = time.time() + + # Load tags + print("Load tags...") + tag_data = json.load(open(tag_file)) + tag_dict = {} + for c in tqdm(tag_data): + tag_dict[c["conversation_id"]] = [x["oai_filter"] for x in c["conversation"]] + print(f"elapsed: {time.time() - tic:.2f} s") + + # Append to input_file + print("Load inputs...") + input_data = json.load(open(in_file)) + for c in tqdm(input_data): + cid = c["conversation_id"] + if cid in tag_dict: + c["openai_moderation"] = tag_dict[cid] + else: + print(f"missing tag for conv {cid}") + exit() + print(f"elapsed: {time.time() - tic:.2f} s") + + # Write output + print("Write outputs...") + out_file = in_file.replace(".json", ".with_tag.json") + print(f"Output to {out_file}") + with open(out_file, "w") as fout: + json.dump(input_data, fout, indent=2, ensure_ascii=False) + print(f"elapsed: {time.time() - tic:.2f} s") diff --git a/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..5bae9fbad221c57eba8f2cf5b7eb2779a6f040a8 --- /dev/null +++ b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh @@ -0,0 +1,18 @@ +export BASE=clean_conv_20230809_1.5M_pii +#export BASE=clean_conv_20230809_100k_pii +export SCALE=1 + +# Filter words +python3 filter_bad_conv.py --in $BASE.json --sample 1000000 + +# Clean up some fileds (e.g., timestamps) +python3 final_post_processing.py --in $BASE.s1.json + +# Upload to hf +python3 upload_hf_dataset.py --in $BASE.s1.s2.json + +# Make another version with openai moderation tag +python3 merge_oai_tag.py --in $BASE.s1.s2.json + +# Make visualizations +python3 compute_stats.py --in $BASE.s1.json --scale $SCALE diff --git a/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/sample.py b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..3b6da455fc7bf8af1ce473f80440bff280c9366e --- /dev/null +++ b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/sample.py @@ -0,0 +1,32 @@ +""" +Count the unique users in a battle log file. 
+ +Usage: +python3 -input in.json --number 1000 +""" + +import argparse +import json +import random + +K = 1000 + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str) + parser.add_argument("--number", type=int, nargs="+") + args = parser.parse_args() + + convs = json.load(open(args.input)) + random.seed(42) + random.shuffle(convs) + + for number in args.number: + new_convs = convs[:number] + + output = args.input.replace(".json", f"_{number//K}k.json") + with open(output, "w") as fout: + json.dump(new_convs, fout, indent=2, ensure_ascii=False) + + print(f"#in: {len(convs)}, #out: {len(new_convs)}") + print(f"Write to file: {output}") diff --git a/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..41d0fbdb59b4c7dc8385bef87a1bf0c8ea6e7401 --- /dev/null +++ b/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py @@ -0,0 +1,17 @@ +""" +Upload to huggingface. +""" +import argparse +import json +from datasets import Dataset, DatasetDict, load_dataset + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + args = parser.parse_args() + + objs = json.load(open(args.in_file)) + print(f"#convs: {len(objs)}") + data = Dataset.from_list(objs) + data.push_to_hub("lmsys/lmsys-chat-1m", private=True) diff --git a/fastchat/serve/monitor/deduplication.py b/fastchat/serve/monitor/deduplication.py new file mode 100644 index 0000000000000000000000000000000000000000..5a4b5765d23549371d92c32b73d951ca58533844 --- /dev/null +++ b/fastchat/serve/monitor/deduplication.py @@ -0,0 +1,85 @@ +import os +import json +import pandas as pd +import ast + +import matplotlib.pyplot as plt +from matplotlib import rcParams + +import argparse +import seaborn as sns +from tqdm import tqdm +import matplotlib.pyplot as plt + +import numpy as np + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--output_dir", type=str, default="output") + parser.add_argument("--model", type=str, default=None) + parser.add_argument("--input_file", type=str, required=True) + parser.add_argument("--percentile", type=float, default=0.9999) + args = parser.parse_args() + output_dir = args.output_dir + input_file = args.input_file + + with open(input_file) as f: + data = json.load(f) + + os.makedirs(output_dir, exist_ok=True) + + # Preprocessing + all_convs_new = [] + convs = [] + for row in data: + conv = "" + for turns in row["conversation_a"]: + if turns["role"] == "user": + conv += f"{turns['content']}\n" + + convs.append(conv[:10000]) + row["post_process_conv"] = conv[:10000] + all_convs_new.append(row) + + df = pd.DataFrame(all_convs_new) + print("Number of conversations: ", len(df)) + + prompt_counts = df["post_process_conv"].value_counts() + # Select the top 20 most frequent prompts + top_prompts = prompt_counts.head(20) + print(top_prompts) + + # Determine the percentile count + percentile_cutoff = prompt_counts.quantile(args.percentile) + print(f"{args.percentile*100} percentile count: {percentile_cutoff}") + + # prompts that are more common than the percentile cutoff + high_frequency_prompts = prompt_counts[prompt_counts > percentile_cutoff].index + print( + f"Number of high frequency prompts: {len(high_frequency_prompts)}/{len(prompt_counts)}" + ) + + # initialize a new 
column dedup_tag + dedup_tags = np.array( + [{"high_freq": False, "sampled": True} for _ in range(len(df))] + ) + high_freq_groups = df.groupby("post_process_conv") + for prompt in tqdm(high_frequency_prompts): + df_high_freq = high_freq_groups.get_group(prompt) + sampled_indices = df_high_freq.sample( + n=int(percentile_cutoff), random_state=42 + ).index + dedup_tags[df_high_freq.index] = {"high_freq": True, "sampled": False} + dedup_tags[sampled_indices] = {"high_freq": True, "sampled": True} + + df["dedup_tag"] = dedup_tags + + # drop intermediate columns (post_process_conv) + df = df.drop(columns=["post_process_conv"]) + + df.to_json( + os.path.join(output_dir, "dedup.json"), + orient="records", + indent=4, + force_ascii=False, + ) diff --git a/fastchat/serve/monitor/elo_analysis.py b/fastchat/serve/monitor/elo_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..4982b2f0db15a781ff2f2a73c9e22505d1752ce5 --- /dev/null +++ b/fastchat/serve/monitor/elo_analysis.py @@ -0,0 +1,622 @@ +import argparse +import ast +from collections import defaultdict +import datetime +import json +import math +import pickle +from pytz import timezone +from functools import partial + +import numpy as np +import pandas as pd +import plotly.express as px +from tqdm import tqdm +from transformers import AutoTokenizer + +from fastchat.model.model_registry import get_model_info +from fastchat.serve.monitor.basic_stats import get_log_files +from fastchat.serve.monitor.clean_battle_data import clean_battle_data + +pd.options.display.float_format = "{:.2f}".format + + +def compute_elo(battles, K=4, SCALE=400, BASE=10, INIT_RATING=1000): + rating = defaultdict(lambda: INIT_RATING) + + for rd, model_a, model_b, winner in battles[ + ["model_a", "model_b", "winner"] + ].itertuples(): + ra = rating[model_a] + rb = rating[model_b] + ea = 1 / (1 + BASE ** ((rb - ra) / SCALE)) + eb = 1 / (1 + BASE ** ((ra - rb) / SCALE)) + if winner == "model_a": + sa = 1 + elif winner == "model_b": + sa = 0 + elif winner == "tie" or winner == "tie (bothbad)": + sa = 0.5 + else: + raise Exception(f"unexpected vote {winner}") + rating[model_a] += K * (sa - ea) + rating[model_b] += K * (1 - sa - eb) + + return dict(rating) + + +def get_bootstrap_result(battles, func_compute_elo, num_round=1000): + rows = [] + for i in tqdm(range(num_round), desc="bootstrap"): + tmp_battles = battles.sample(frac=1.0, replace=True) + rows.append(func_compute_elo(tmp_battles)) + df = pd.DataFrame(rows) + return df[df.median().sort_values(ascending=False).index] + + +def compute_elo_mle_with_tie( + df, SCALE=400, BASE=10, INIT_RATING=1000, sample_weight=None +): + from sklearn.linear_model import LogisticRegression + + ptbl_a_win = pd.pivot_table( + df[df["winner"] == "model_a"], + index="model_a", + columns="model_b", + aggfunc="size", + fill_value=0, + ) + ptbl_tie = pd.pivot_table( + df[df["winner"].isin(["tie", "tie (bothbad)"])], + index="model_a", + columns="model_b", + aggfunc="size", + fill_value=0, + ) + ptbl_tie = ptbl_tie + ptbl_tie.T + ptbl_b_win = pd.pivot_table( + df[df["winner"] == "model_b"], + index="model_a", + columns="model_b", + aggfunc="size", + fill_value=0, + ) + ptbl_win = ptbl_a_win * 2 + ptbl_b_win.T * 2 + ptbl_tie + + models = pd.Series(np.arange(len(ptbl_win.index)), index=ptbl_win.index) + + p = len(models) + X = np.zeros([p * (p - 1) * 2, p]) + Y = np.zeros(p * (p - 1) * 2) + + cur_row = 0 + sample_weights = [] + for m_a in ptbl_win.index: + for m_b in ptbl_win.columns: + if m_a == m_b: + continue + # if 
nan skip + if math.isnan(ptbl_win.loc[m_a, m_b]) or math.isnan(ptbl_win.loc[m_b, m_a]): + continue + X[cur_row, models[m_a]] = +math.log(BASE) + X[cur_row, models[m_b]] = -math.log(BASE) + Y[cur_row] = 1.0 + sample_weights.append(ptbl_win.loc[m_a, m_b]) + + X[cur_row + 1, models[m_a]] = math.log(BASE) + X[cur_row + 1, models[m_b]] = -math.log(BASE) + Y[cur_row + 1] = 0.0 + sample_weights.append(ptbl_win.loc[m_b, m_a]) + cur_row += 2 + X = X[:cur_row] + Y = Y[:cur_row] + + lr = LogisticRegression(fit_intercept=False, penalty=None) + lr.fit(X, Y, sample_weight=sample_weights) + elo_scores = SCALE * lr.coef_[0] + INIT_RATING + if "mixtral-8x7b-instruct-v0.1" in models.index: + elo_scores += 1114 - elo_scores[models["mixtral-8x7b-instruct-v0.1"]] + return pd.Series(elo_scores, index=models.index).sort_values(ascending=False) + + +def get_median_elo_from_bootstrap(bootstrap_df): + median = dict(bootstrap_df.quantile(0.5)) + median = {k: int(v + 0.5) for k, v in median.items()} + return median + + +def compute_pairwise_win_fraction(battles, model_order, limit_show_number=None): + # Times each model wins as Model A + a_win_ptbl = pd.pivot_table( + battles[battles["winner"] == "model_a"], + index="model_a", + columns="model_b", + aggfunc="size", + fill_value=0, + ) + + # Table counting times each model wins as Model B + b_win_ptbl = pd.pivot_table( + battles[battles["winner"] == "model_b"], + index="model_a", + columns="model_b", + aggfunc="size", + fill_value=0, + ) + + # Table counting number of A-B pairs + num_battles_ptbl = pd.pivot_table( + battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0 + ) + + # Computing the proportion of wins for each model as A and as B + # against all other models + row_beats_col_freq = (a_win_ptbl + b_win_ptbl.T) / ( + num_battles_ptbl + num_battles_ptbl.T + ) + + if model_order is None: + prop_wins = row_beats_col_freq.mean(axis=1).sort_values(ascending=False) + model_order = list(prop_wins.keys()) + + if limit_show_number is not None: + model_order = model_order[:limit_show_number] + + # Arrange ordering according to proprition of wins + row_beats_col = row_beats_col_freq.loc[model_order, model_order] + return row_beats_col + + +def visualize_leaderboard_table(rating): + models = list(rating.keys()) + models.sort(key=lambda k: -rating[k]) + + emoji_dict = { + 1: "🥇", + 2: "🥈", + 3: "🥉", + } + + md = "" + md += "| Rank | Model | Elo Rating | Description |\n" + md += "| --- | --- | --- | --- |\n" + for i, model in enumerate(models): + rank = i + 1 + minfo = get_model_info(model) + emoji = emoji_dict.get(rank, "") + md += f"| {rank} | {emoji} [{model}]({minfo.link}) | {rating[model]:.0f} | {minfo.description} |\n" + + return md + + +def visualize_pairwise_win_fraction(battles, model_order, scale=1): + row_beats_col = compute_pairwise_win_fraction(battles, model_order) + fig = px.imshow( + row_beats_col, + color_continuous_scale="RdBu", + text_auto=".2f", + height=700 * scale, + width=700 * scale, + ) + fig.update_layout( + xaxis_title="Model B", + yaxis_title="Model A", + xaxis_side="top", + title_y=0.07, + title_x=0.5, + ) + fig.update_traces( + hovertemplate="Model A: %{y}
<br>Model B: %{x}<br>
Fraction of A Wins: %{z}" + ) + + return fig + + +def visualize_battle_count(battles, model_order, scale=1): + ptbl = pd.pivot_table( + battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0 + ) + battle_counts = ptbl + ptbl.T + fig = px.imshow( + battle_counts.loc[model_order, model_order], + text_auto=True, + height=700 * scale, + width=700 * scale, + ) + fig.update_layout( + xaxis_title="Model B", + yaxis_title="Model A", + xaxis_side="top", + title_y=0.07, + title_x=0.5, + ) + fig.update_traces( + hovertemplate="Model A: %{y}
<br>Model B: %{x}<br>
Count: %{z}" + ) + return fig + + +def visualize_average_win_rate(battles, limit_show_number, scale=1): + row_beats_col_freq = compute_pairwise_win_fraction( + battles, None, limit_show_number=limit_show_number + ) + fig = px.bar( + row_beats_col_freq.mean(axis=1).sort_values(ascending=False), + text_auto=".2f", + height=500 * scale, + width=700 * scale, + ) + fig.update_layout( + yaxis_title="Average Win Rate", xaxis_title="Model", showlegend=False + ) + return fig + + +def visualize_bootstrap_elo_rating(df, df_final, limit_show_number, scale=1): + bars = ( + pd.DataFrame( + dict( + lower=df.quantile(0.025), + rating=df_final, + upper=df.quantile(0.975), + ) + ) + .reset_index(names="model") + .sort_values("rating", ascending=False) + ) + bars = bars[:limit_show_number] + bars["error_y"] = bars["upper"] - bars["rating"] + bars["error_y_minus"] = bars["rating"] - bars["lower"] + bars["rating_rounded"] = np.round(bars["rating"]) + fig = px.scatter( + bars, + x="model", + y="rating", + error_y="error_y", + error_y_minus="error_y_minus", + text="rating_rounded", + height=700, + width=700 * scale, + ) + fig.update_layout(xaxis_title="Model", yaxis_title="Rating") + return fig + + +def limit_user_votes(battles, daily_vote_per_user): + from datetime import datetime + + print("Before limiting user votes: ", len(battles)) + # add date + battles["date"] = battles["tstamp"].apply( + lambda x: datetime.fromtimestamp(x).strftime("%Y-%m-%d") + ) + + battles_new = pd.DataFrame() + for date in battles["date"].unique(): + # only take the first daily_vote_per_user votes per judge per day + df_today = battles[battles["date"] == date] + df_sub = df_today.groupby("judge").head(daily_vote_per_user) + + # add df_sub to a new dataframe + battles_new = pd.concat([battles_new, df_sub]) + print("After limiting user votes: ", len(battles_new)) + return battles_new + + +def get_model_pair_stats(battles): + battles["ordered_pair"] = battles.apply( + lambda x: tuple(sorted([x["model_a"], x["model_b"]])), axis=1 + ) + + model_pair_stats = {} + + for index, row in battles.iterrows(): + pair = row["ordered_pair"] + if pair not in model_pair_stats: + model_pair_stats[pair] = {"win": 0, "loss": 0, "tie": 0} + + if row["winner"] in ["tie", "tie (bothbad)"]: + model_pair_stats[pair]["tie"] += 1 + elif row["winner"] == "model_a" and row["model_a"] == min(pair): + model_pair_stats[pair]["win"] += 1 + elif row["winner"] == "model_b" and row["model_b"] == min(pair): + model_pair_stats[pair]["win"] += 1 + else: + model_pair_stats[pair]["loss"] += 1 + + return model_pair_stats + + +def outlier_detect( + model_pair_stats, + battles, + max_vote=100, + randomized=False, + alpha=0.05, + c_param=0.5, + user_list=None, +): + if user_list is None: + # only check user who has >= 5 votes to save compute + user_vote_cnt = battles["judge"].value_counts() + user_list = user_vote_cnt[user_vote_cnt >= 5].index.tolist() + print("#User to be checked: ", len(user_list)) + + bad_user_list = [] + for user in user_list: + flag = False + p_upper = [] + p_lower = [] + df_2 = battles[battles["judge"] == user] + for row in df_2.iterrows(): + if len(p_upper) >= max_vote: + break + + model_pair = tuple(sorted([row[1]["model_a"], row[1]["model_b"]])) + + if row[1]["winner"] in ["tie", "tie (bothbad)"]: + vote = 0.5 + elif row[1]["winner"] == "model_a" and row[1]["model_a"] == model_pair[0]: + vote = 1 + elif row[1]["winner"] == "model_b" and row[1]["model_b"] == model_pair[0]: + vote = 1 + else: + vote = 0 + + stats = model_pair_stats[model_pair] + # count 
all votes + # ratings = np.array( + # [1] * stats["win"] + [0.5] * stats["tie"] + [0] * stats["loss"] + # ) + + # only count win and loss + ratings = np.array([1] * stats["win"] + [0] * stats["loss"]) + if randomized: + noise = np.random.uniform(-1e-5, 1e-5, len(ratings)) + ratings += noise + vote += np.random.uniform(-1e-5, 1e-5) + + p_upper += [(ratings <= vote).mean()] + p_lower += [(ratings >= vote).mean()] + + M_upper = np.prod(1 / (2 * np.array(p_upper))) + M_lower = np.prod(1 / (2 * np.array(p_lower))) + + # M_upper = np.prod((1 - c_param) / (c_param * np.array(p_upper) ** c_param)) + # M_lower = np.prod((1 - c_param) / (c_param * np.array(p_lower) ** c_param)) + if (M_upper > 1 / alpha) or (M_lower > 1 / alpha): + print(f"Identify bad user with {len(p_upper)} votes") + flag = True + break + if flag: + bad_user_list.append({"user_id": user, "votes": len(p_upper)}) + print("Bad user length: ", len(bad_user_list)) + print(bad_user_list) + + bad_user_id_list = [x["user_id"] for x in bad_user_list] + # remove bad users + battles = battles[~battles["judge"].isin(bad_user_id_list)] + return battles + + +def filter_long_conv(row): + threshold = 768 + for conversation_type in ["conversation_a", "conversation_b"]: + cur_conv = row[conversation_type] + num_tokens_all = sum([turn["num_tokens"] for turn in cur_conv]) + if num_tokens_all >= threshold: + return True + return False + + +def report_elo_analysis_results( + battles_json, + rating_system="bt", + num_bootstrap=100, + exclude_models=[], + langs=[], + exclude_tie=False, + exclude_unknown_lang=False, + daily_vote_per_user=None, + run_outlier_detect=False, + scale=1, + filter_func=lambda x: True, +): + battles = pd.DataFrame(battles_json) + + tqdm.pandas(desc=f"Processing using {filter_func.__name__}") + filtered_indices = battles.progress_apply(filter_func, axis=1) + battles = battles[filtered_indices] + + battles = battles.sort_values(ascending=True, by=["tstamp"]) + + if len(langs) > 0: + battles = battles[battles["language"].isin(langs)] + if exclude_unknown_lang: + battles = battles[~battles["language"].str.contains("unknown")] + + # remove excluded models + battles = battles[ + ~( + battles["model_a"].isin(exclude_models) + | battles["model_b"].isin(exclude_models) + ) + ] + + # Only use anonymous votes + battles = battles[battles["anony"]].reset_index(drop=True) + battles_no_ties = battles[~battles["winner"].str.contains("tie")] + if exclude_tie: + battles = battles_no_ties + + if daily_vote_per_user is not None: + battles = limit_user_votes(battles, daily_vote_per_user) + + if run_outlier_detect: + model_pair_stats = get_model_pair_stats(battles) + battles = outlier_detect(model_pair_stats, battles) + + print(f"Number of battles: {len(battles)}") + # Online update + elo_rating_online = compute_elo(battles) + + if rating_system == "bt": + bootstrap_df = get_bootstrap_result( + battles, compute_elo_mle_with_tie, num_round=num_bootstrap + ) + elo_rating_final = compute_elo_mle_with_tie(battles) + elif rating_system == "elo": + bootstrap_df = get_bootstrap_result( + battles, compute_elo, num_round=num_bootstrap + ) + elo_rating_median = get_median_elo_from_bootstrap(bootstrap_df) + elo_rating_final = elo_rating_median + + model_order = list(elo_rating_final.keys()) + + model_rating_q025 = bootstrap_df.quantile(0.025) + model_rating_q975 = bootstrap_df.quantile(0.975) + + # compute ranking based on CI + ranking = {} + for i, model_a in enumerate(model_order): + ranking[model_a] = 1 + for j, model_b in enumerate(model_order): + if i == 
j: + continue + if model_rating_q025[model_b] > model_rating_q975[model_a]: + ranking[model_a] += 1 + + # leaderboard_table_df: elo rating, variance, 95% interval, number of battles + leaderboard_table_df = pd.DataFrame( + { + "rating": elo_rating_final, + "variance": bootstrap_df.var(), + "rating_q975": bootstrap_df.quantile(0.975), + "rating_q025": bootstrap_df.quantile(0.025), + "num_battles": battles["model_a"] + .value_counts() + .add(battles["model_b"].value_counts(), fill_value=0), + "final_ranking": pd.Series(ranking), + } + ) + + model_order.sort(key=lambda k: -elo_rating_final[k]) + limit_show_number = int(25 * scale) + model_order = model_order[:limit_show_number] + + # Plots + leaderboard_table = visualize_leaderboard_table(elo_rating_final) + win_fraction_heatmap = visualize_pairwise_win_fraction( + battles_no_ties, model_order, scale=scale + ) + battle_count_heatmap = visualize_battle_count( + battles_no_ties, model_order, scale=scale + ) + average_win_rate_bar = visualize_average_win_rate( + battles_no_ties, limit_show_number, scale=scale + ) + bootstrap_elo_rating = visualize_bootstrap_elo_rating( + bootstrap_df, elo_rating_final, limit_show_number, scale=scale + ) + + last_updated_tstamp = battles["tstamp"].max() + last_updated_datetime = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + + return { + "rating_system": rating_system, + "elo_rating_online": elo_rating_online, + "elo_rating_final": elo_rating_final, + "leaderboard_table": leaderboard_table, + "win_fraction_heatmap": win_fraction_heatmap, + "battle_count_heatmap": battle_count_heatmap, + "average_win_rate_bar": average_win_rate_bar, + "bootstrap_elo_rating": bootstrap_elo_rating, + "last_updated_datetime": last_updated_datetime, + "last_updated_tstamp": last_updated_tstamp, + "bootstrap_df": bootstrap_df, + "leaderboard_table_df": leaderboard_table_df, + } + + +def pretty_print_elo_rating(rating): + model_order = list(rating.keys()) + model_order.sort(key=lambda k: -rating[k]) + for i, model in enumerate(model_order): + print(f"{i+1:2d}, {model:25s}, {rating[model]:.0f}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--clean-battle-file", type=str) + parser.add_argument("--max-num-files", type=int) + parser.add_argument("--num-bootstrap", type=int, default=100) + parser.add_argument( + "--rating-system", type=str, choices=["bt", "elo"], default="bt" + ) + parser.add_argument("--exclude-models", type=str, nargs="+", default=[]) + parser.add_argument("--exclude-tie", action="store_true", default=False) + parser.add_argument("--exclude-unknown-lang", action="store_true", default=False) + parser.add_argument("--exclude-url", action="store_true", default=False) + parser.add_argument("--langs", type=str, nargs="+", default=[]) + parser.add_argument("--daily-vote-per-user", type=int, default=None) + parser.add_argument("--run-outlier-detect", action="store_true", default=False) + parser.add_argument("--category", nargs="+", default=["full"]) + parser.add_argument("--scale", type=float, default=1) + args = parser.parse_args() + + np.random.seed(42) + + if args.clean_battle_file: + # Read data from a cleaned battle files + battles = pd.read_json(args.clean_battle_file) + else: + # Read data from all log files + log_files = get_log_files(args.max_num_files) + battles = clean_battle_data(log_files) + + filter_func_map = { + "full": lambda x: True, + "long": filter_long_conv, + "chinese": lambda x: 
x["language"] == "Chinese", + "english": lambda x: x["language"] == "English", + } + assert all( + [cat in filter_func_map for cat in args.category] + ), f"Invalid category: {args.category}" + + results = {} + for cat in args.category: + filter_func = filter_func_map[cat] + results[cat] = report_elo_analysis_results( + battles, + rating_system=args.rating_system, + num_bootstrap=args.num_bootstrap, + exclude_models=args.exclude_models, + langs=args.langs, + exclude_tie=args.exclude_tie, + exclude_unknown_lang=args.exclude_unknown_lang, + daily_vote_per_user=args.daily_vote_per_user, + run_outlier_detect=args.run_outlier_detect, + scale=args.scale, + filter_func=filter_func, + ) + + for cat in args.category: + print(f"# Results for {cat} conversations") + print("# Online Elo") + pretty_print_elo_rating(results[cat]["elo_rating_online"]) + print("# Median") + pretty_print_elo_rating(results[cat]["elo_rating_final"]) + print(f"last update : {results[cat]['last_updated_datetime']}") + + last_updated_tstamp = results[cat]["last_updated_tstamp"] + cutoff_date = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y%m%d") + print(f"last update : {cutoff_date}") + + with open(f"elo_results_{cutoff_date}.pkl", "wb") as fout: + pickle.dump(results, fout) diff --git a/fastchat/serve/monitor/inspect_conv.py b/fastchat/serve/monitor/inspect_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..a680a419bd9d11d0db85afbc21c0063a2ae36df7 --- /dev/null +++ b/fastchat/serve/monitor/inspect_conv.py @@ -0,0 +1,87 @@ +import argparse +import code +import datetime +import json +import os +from pytz import timezone +import time + +import pandas as pd +from tqdm import tqdm + + +def get_log_files(max_num_files=None): + dates = [] + for month in [4, 5]: + for day in range(1, 32): + dates.append(f"2023-{month:02d}-{day:02d}") + + num_servers = 14 + filenames = [] + for d in dates: + for i in range(num_servers): + name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") + if os.path.exists(name): + filenames.append(name) + max_num_files = max_num_files or len(filenames) + filenames = filenames[-max_num_files:] + return filenames + + +def pretty_print_conversation(messages): + for role, msg in messages: + print(f"[[{role}]]: {msg}") + + +def inspect_convs(log_files): + data = [] + for filename in tqdm(log_files, desc="read files"): + for retry in range(5): + try: + lines = open(filename).readlines() + break + except FileNotFoundError: + time.sleep(2) + + for l in lines: + row = json.loads(l) + + if "states" not in row: + continue + if row["type"] not in ["leftvote", "rightvote", "bothbad_vote"]: + continue + + model_names = row["states"][0]["model_name"], row["states"][1]["model_name"] + if row["type"] == "leftvote": + winner, loser = model_names[0], model_names[1] + winner_conv, loser_conv = row["states"][0], row["states"][1] + elif row["type"] == "rightvote": + loser, winner = model_names[0], model_names[1] + loser_conv, winner_conv = row["states"][0], row["states"][1] + + if loser == "bard" and winner == "vicuna-13b": + print("=" * 20) + print(f"Winner: {winner}") + pretty_print_conversation(winner_conv["messages"]) + print(f"Loser: {loser}") + pretty_print_conversation(loser_conv["messages"]) + print("=" * 20) + input() + + # if row["type"] == "bothbad_vote" and "gpt-4" in model_names: + # print("=" * 20) + # print(f"Model A: {model_names[0]}") + # pretty_print_conversation(row["states"][0]["messages"]) + # print(f"Model B: 
{model_names[1]}") + # pretty_print_conversation(row["states"][1]["messages"]) + # print("=" * 20) + # input() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--max-num-files", type=int) + args = parser.parse_args() + + log_files = get_log_files(args.max_num_files) + inspect_convs(log_files) diff --git a/fastchat/serve/monitor/intersect_conv_file.py b/fastchat/serve/monitor/intersect_conv_file.py new file mode 100644 index 0000000000000000000000000000000000000000..9eadd7cd57510ecbbd23798d55b079c69aac1a12 --- /dev/null +++ b/fastchat/serve/monitor/intersect_conv_file.py @@ -0,0 +1,25 @@ +""" +Take the intersection of two conversation files. + +Usage: python3 -m fastchat.data.merge --input input.json --conv-id conv_id_file.json --out intersect.json +""" + +import argparse +import json + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str, required=True) + parser.add_argument("--conv-id", type=str, required=True) + parser.add_argument("--out-file", type=str, default="intersect.json") + args = parser.parse_args() + + conv_id_objs = json.load(open(args.conv_id, "r")) + conv_ids = set(x["conversation_id"] for x in conv_id_objs) + + objs = json.load(open(args.input, "r")) + after_objs = [x for x in objs if x["conversation_id"] in conv_ids] + + print(f"#in: {len(objs)}, #out: {len(after_objs)}") + json.dump(after_objs, open(args.out_file, "w"), indent=2, ensure_ascii=False) diff --git a/fastchat/serve/monitor/leaderboard_csv_to_html.py b/fastchat/serve/monitor/leaderboard_csv_to_html.py new file mode 100644 index 0000000000000000000000000000000000000000..ad52e7b2b6e234ed33a51d516e9d682addd1e0eb --- /dev/null +++ b/fastchat/serve/monitor/leaderboard_csv_to_html.py @@ -0,0 +1,51 @@ +""" +Convert a leaderboard csv file to html table used in the blog. + +Usage: +python3 leaderboard_csv_to_html.py --in leaderboard_table_20230619.csv +""" +import argparse + +import numpy as np + +from fastchat.serve.monitor.monitor import load_leaderboard_table_csv + + +def model_hyperlink(model_name, link): + return f' {model_name} ' + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str, required=True) + args = parser.parse_args() + + data = load_leaderboard_table_csv(args.input, add_hyperlink=False) + headers = [ + "Model", + "MT-bench (score)", + "Arena Elo rating", + "MMLU", + "License", + ] + values = [] + for item in data: + row = [] + for key in headers: + value = item[key] + row.append(value) + row[0] = model_hyperlink(item["Model"], item["Link"]) + values.append(row) + values.sort(key=lambda x: -x[1] if not np.isnan(x[1]) else 1e9) + + for value in values: + row = "" + for x in value: + try: + if np.isnan(x): + x = "-" + except TypeError: + pass + row += f" {x} " + row += "" + print(row) diff --git a/fastchat/serve/monitor/monitor.py b/fastchat/serve/monitor/monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..e112ab613cdb09feb7d3dd16efc791b9d82f04be --- /dev/null +++ b/fastchat/serve/monitor/monitor.py @@ -0,0 +1,1033 @@ +""" +Live monitor of the website statistics and leaderboard. 
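+
+Usage (example invocation; file names are illustrative):
+python3 -m fastchat.serve.monitor.monitor --elo-results-file elo_results_20240101.pkl --leaderboard-table-file leaderboard_table_20240101.csv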
+ +Dependency: +sudo apt install pkg-config libicu-dev +pip install pytz gradio gdown plotly polyglot pyicu pycld2 tabulate +""" + +import argparse +import ast +import json +import pickle +import os +import threading +import time + +import pandas as pd +import gradio as gr +import numpy as np + +from fastchat.constants import SURVEY_LINK +from fastchat.serve.monitor.basic_stats import report_basic_stats, get_log_files +from fastchat.serve.monitor.clean_battle_data import clean_battle_data +from fastchat.serve.monitor.elo_analysis import report_elo_analysis_results +from fastchat.utils import build_logger, get_window_url_params_js + + +notebook_url = ( + "https://colab.research.google.com/drive/1KdwokPjirkTmpO_P1WByFNFiqxWQquwH" +) + +basic_component_values = [None] * 6 +leader_component_values = [None] * 5 + + +def make_default_md_1(mirror=False): + link_color = "#1976D2" # This color should be clear in both light and dark mode + leaderboard_md = f""" + # 🏆 LMSYS Chatbot Arena Leaderboard + [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2403.04132) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | [Kaggle Competition](https://www.kaggle.com/competitions/lmsys-chatbot-arena) + """ + + return leaderboard_md + + +def make_default_md_2(mirror=False): + mirror_str = "This is a mirror of the live leaderboard created and maintained by the LMSYS Organization. Please link to leaderboard.lmsys.org for citation purposes." + leaderboard_md = f""" +{mirror_str if mirror else ""} + +LMSYS Chatbot Arena is a crowdsourced open platform for LLM evals. We've collected over 1,000,000 human pairwise comparisons to rank LLMs with the Bradley-Terry model and display the model ratings in Elo-scale. +You can find more details in our paper. **Chatbot arena is dependent on community participation, please contribute by casting your vote!** + +{SURVEY_LINK} +""" + + return leaderboard_md + + +def make_arena_leaderboard_md(arena_df, last_updated_time, vision=False): + total_votes = sum(arena_df["num_battles"]) // 2 + total_models = len(arena_df) + space = "   " + + leaderboard_md = f""" +Total #models: **{total_models}**.{space} Total #votes: **{"{:,}".format(total_votes)}**.{space} Last updated: {last_updated_time}. +""" + if not vision: + leaderboard_md += """ +📣 **NEW!** View leaderboard for different categories (e.g., coding, long user query)! This is still in preview and subject to change. +""" + + leaderboard_md += f""" +Code to recreate leaderboard tables and plots in this [notebook]({notebook_url}). You can contribute your vote at [chat.lmsys.org](https://chat.lmsys.org)! +""" + return leaderboard_md + + +def make_category_arena_leaderboard_md(arena_df, arena_subset_df, name="Overall"): + total_votes = sum(arena_df["num_battles"]) // 2 + total_models = len(arena_df) + space = "   " + total_subset_votes = sum(arena_subset_df["num_battles"]) // 2 + total_subset_models = len(arena_subset_df) + leaderboard_md = f"""### {cat_name_to_explanation[name]} +#### {space} #models: **{total_subset_models} ({round(total_subset_models/total_models *100)}%)** {space} #votes: **{"{:,}".format(total_subset_votes)} ({round(total_subset_votes/total_votes * 100)}%)**{space} +""" + return leaderboard_md + + +def make_full_leaderboard_md(): + leaderboard_md = """ +Three benchmarks are displayed: **Arena Elo**, **MT-Bench** and **MMLU**. 
+- [Chatbot Arena](https://chat.lmsys.org/?arena) - a crowdsourced, randomized battle platform. We use 500K+ user votes to compute model strength. +- [MT-Bench](https://arxiv.org/abs/2306.05685): a set of challenging multi-turn questions. We use GPT-4 to grade the model responses. +- [MMLU](https://arxiv.org/abs/2009.03300) (5-shot): a test to measure a model's multitask accuracy on 57 tasks. + +💻 Code: The MT-bench scores (single-answer grading on a scale of 10) are computed by [fastchat.llm_judge](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge). +The MMLU scores are mostly computed by [InstructEval](https://github.com/declare-lab/instruct-eval). +Higher values are better for all benchmarks. Empty cells mean not available. +""" + return leaderboard_md + + +def make_leaderboard_md_live(elo_results): + leaderboard_md = f""" +# Leaderboard +Last updated: {elo_results["last_updated_datetime"]} +{elo_results["leaderboard_table"]} +""" + return leaderboard_md + + +def arena_hard_title(date): + arena_hard_title = f""" +Last Updated: {date} + +**Arena-Hard-Auto v0.1** - an automatic evaluation tool for instruction-tuned LLMs with 500 challenging user queries curated from Chatbot Arena. + +We prompt GPT-4-Turbo as judge to compare the models' responses against a baseline model (default: GPT-4-0314). If you are curious to see how well your model might perform on Chatbot Arena, we recommend trying Arena-Hard-Auto. Check out our paper for more details about how Arena-Hard-Auto works as an fully automated data pipeline converting crowdsourced data into high-quality benchmarks -> +[[Paper](https://arxiv.org/abs/2406.11939) | [Repo](https://github.com/lm-sys/arena-hard-auto)] + """ + return arena_hard_title + + +def update_elo_components( + max_num_files, elo_results_file, ban_ip_file, exclude_model_names +): + log_files = get_log_files(max_num_files) + + # Leaderboard + if elo_results_file is None: # Do live update + ban_ip_list = json.load(open(ban_ip_file)) if ban_ip_file else None + battles = clean_battle_data( + log_files, exclude_model_names, ban_ip_list=ban_ip_list + ) + elo_results = report_elo_analysis_results(battles, scale=2) + + leader_component_values[0] = make_leaderboard_md_live(elo_results) + leader_component_values[1] = elo_results["win_fraction_heatmap"] + leader_component_values[2] = elo_results["battle_count_heatmap"] + leader_component_values[3] = elo_results["bootstrap_elo_rating"] + leader_component_values[4] = elo_results["average_win_rate_bar"] + + # Basic stats + basic_stats = report_basic_stats(log_files) + md0 = f"Last updated: {basic_stats['last_updated_datetime']}" + + md1 = "### Action Histogram\n" + md1 += basic_stats["action_hist_md"] + "\n" + + md2 = "### Anony. 
Vote Histogram\n" + md2 += basic_stats["anony_vote_hist_md"] + "\n" + + md3 = "### Model Call Histogram\n" + md3 += basic_stats["model_hist_md"] + "\n" + + md4 = "### Model Call (Last 24 Hours)\n" + md4 += basic_stats["num_chats_last_24_hours"] + "\n" + + basic_component_values[0] = md0 + basic_component_values[1] = basic_stats["chat_dates_bar"] + basic_component_values[2] = md1 + basic_component_values[3] = md2 + basic_component_values[4] = md3 + basic_component_values[5] = md4 + + +def update_worker( + max_num_files, interval, elo_results_file, ban_ip_file, exclude_model_names +): + while True: + tic = time.time() + update_elo_components( + max_num_files, elo_results_file, ban_ip_file, exclude_model_names + ) + durtaion = time.time() - tic + print(f"update duration: {durtaion:.2f} s") + time.sleep(max(interval - durtaion, 0)) + + +def load_demo(url_params, request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") + return basic_component_values + leader_component_values + + +def model_hyperlink(model_name, link): + return f'{model_name}' + + +def load_leaderboard_table_csv(filename, add_hyperlink=True): + lines = open(filename).readlines() + heads = [v.strip() for v in lines[0].split(",")] + rows = [] + for i in range(1, len(lines)): + row = [v.strip() for v in lines[i].split(",")] + for j in range(len(heads)): + item = {} + for h, v in zip(heads, row): + if h == "Arena Elo rating": + if v != "-": + v = int(ast.literal_eval(v)) + else: + v = np.nan + elif h == "MMLU": + if v != "-": + v = round(ast.literal_eval(v) * 100, 1) + else: + v = np.nan + elif h == "MT-bench (win rate %)": + if v != "-": + v = round(ast.literal_eval(v[:-1]), 1) + else: + v = np.nan + elif h == "MT-bench (score)": + if v != "-": + v = round(ast.literal_eval(v), 2) + else: + v = np.nan + item[h] = v + if add_hyperlink: + item["Model"] = model_hyperlink(item["Model"], item["Link"]) + rows.append(item) + + return rows + + +def build_basic_stats_tab(): + empty = "Loading ..." 
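+ # "Loading ..." placeholders; update_elo_components() fills basic_component_values with real stats on the next refresh.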
+ basic_component_values[:] = [empty, None, empty, empty, empty, empty] + + md0 = gr.Markdown(empty) + gr.Markdown("#### Figure 1: Number of model calls and votes") + plot_1 = gr.Plot(show_label=False) + with gr.Row(): + with gr.Column(): + md1 = gr.Markdown(empty) + with gr.Column(): + md2 = gr.Markdown(empty) + with gr.Row(): + with gr.Column(): + md3 = gr.Markdown(empty) + with gr.Column(): + md4 = gr.Markdown(empty) + return [md0, plot_1, md1, md2, md3, md4] + + +def get_full_table(arena_df, model_table_df, model_to_score): + values = [] + for i in range(len(model_table_df)): + row = [] + model_key = model_table_df.iloc[i]["key"] + model_name = model_table_df.iloc[i]["Model"] + # model display name + row.append(model_name) + if model_key in arena_df.index: + idx = arena_df.index.get_loc(model_key) + row.append(round(arena_df.iloc[idx]["rating"])) + else: + row.append(np.nan) + if model_name in model_to_score: + row.append(model_to_score[model_name]) + else: + row.append(np.nan) + row.append(model_table_df.iloc[i]["MT-bench (score)"]) + row.append(model_table_df.iloc[i]["MMLU"]) + # Organization + row.append(model_table_df.iloc[i]["Organization"]) + # license + row.append(model_table_df.iloc[i]["License"]) + + values.append(row) + values.sort(key=lambda x: -x[1] if not np.isnan(x[1]) else 1e9) + return values + + +def arena_hard_process(leaderboard_table_file, filepath): + arena_hard = pd.read_csv(filepath) + leaderboard_table = pd.read_csv(leaderboard_table_file) + links = leaderboard_table.get("Link") + display_name = leaderboard_table.get("Model") + model_name = leaderboard_table.get("key") + organization = leaderboard_table.get("Organization") + + info = {} + for i in range(len(model_name)): + model_info = {} + model_info["display"] = display_name[i] + model_info["link"] = links[i] + model_info["org"] = organization[i] + info[model_name[i]] = model_info + + organization = [] + for i in range(len(arena_hard)): + assert ( + arena_hard.loc[i, "model"] in info + ), f"need to update leaderboard_table info by adding {arena_hard.loc[i, 'model']}" + organization.append(info[arena_hard.loc[i, "model"]]["org"]) + link = info[arena_hard.loc[i, "model"]]["link"] + arena_hard.loc[i, "model"] = model_hyperlink( + info[arena_hard.loc[i, "model"]]["display"], link + ) + + arena_hard.insert( + loc=len(arena_hard.columns), column="Organization", value=organization + ) + + rankings = recompute_final_ranking(arena_hard) + arena_hard.insert(loc=0, column="Rank* (UB)", value=rankings) + return arena_hard + + +def create_ranking_str(ranking, ranking_difference): + if ranking_difference > 0: + return f"{int(ranking)} \u2191" + elif ranking_difference < 0: + return f"{int(ranking)} \u2193" + else: + return f"{int(ranking)}" + + +def recompute_final_ranking(arena_df): + # compute ranking based on CI + ranking = {} + for i, model_a in enumerate(arena_df.index): + ranking[model_a] = 1 + for j, model_b in enumerate(arena_df.index): + if i == j: + continue + if ( + arena_df.loc[model_b]["rating_q025"] + > arena_df.loc[model_a]["rating_q975"] + ): + ranking[model_a] += 1 + return list(ranking.values()) + + +def highlight_top_models(df): + def highlight_max_rank(s): + # Pastel Yellow with transparency, rgba(red, green, blue, alpha) + highlight_color = "rgba(255, 255, 128, 0.2)" # 50% transparent + if int(s["Rank* (UB)"].replace("↑", "").replace("↓", "")) == 1: + return [f"background-color: {highlight_color}" for _ in s] + else: + return ["" for _ in s] + + # Apply and return the styled DataFrame + return 
df.apply(highlight_max_rank, axis=1) + + +def get_arena_table(arena_df, model_table_df, arena_subset_df=None): + arena_df = arena_df.sort_values( + by=["final_ranking", "rating"], ascending=[True, False] + ) + arena_df["final_ranking"] = recompute_final_ranking(arena_df) + arena_df = arena_df.sort_values( + by=["final_ranking", "rating"], ascending=[True, False] + ) + + # sort by rating + if arena_subset_df is not None: + # filter out models not in the arena_df + arena_subset_df = arena_subset_df[arena_subset_df.index.isin(arena_df.index)] + arena_subset_df = arena_subset_df.sort_values(by=["rating"], ascending=False) + arena_subset_df["final_ranking"] = recompute_final_ranking(arena_subset_df) + # keep only the models in the subset in arena_df and recompute final_ranking + arena_df = arena_df[arena_df.index.isin(arena_subset_df.index)] + # recompute final ranking + arena_df["final_ranking"] = recompute_final_ranking(arena_df) + + # assign ranking by the order + arena_subset_df["final_ranking_no_tie"] = range(1, len(arena_subset_df) + 1) + arena_df["final_ranking_no_tie"] = range(1, len(arena_df) + 1) + # join arena_df and arena_subset_df on index + arena_df = arena_subset_df.join( + arena_df["final_ranking"], rsuffix="_global", how="inner" + ) + arena_df["ranking_difference"] = ( + arena_df["final_ranking_global"] - arena_df["final_ranking"] + ) + + arena_df = arena_df.sort_values( + by=["final_ranking", "rating"], ascending=[True, False] + ) + arena_df["final_ranking"] = arena_df.apply( + lambda x: create_ranking_str(x["final_ranking"], x["ranking_difference"]), + axis=1, + ) + + arena_df["final_ranking"] = arena_df["final_ranking"].astype(str) + + values = [] + for i in range(len(arena_df)): + row = [] + model_key = arena_df.index[i] + try: # this is a janky fix for where the model key is not in the model table (model table and arena table dont contain all the same models) + model_name = model_table_df[model_table_df["key"] == model_key][ + "Model" + ].values[0] + # rank + ranking = arena_df.iloc[i].get("final_ranking") or i + 1 + row.append(ranking) + if arena_subset_df is not None: + row.append(arena_df.iloc[i].get("ranking_difference") or 0) + # model display name + row.append(model_name) + # elo rating + rating = f"{round(arena_df.iloc[i]['rating'])}" + row.append(rating) + upper_diff = round( + arena_df.iloc[i]["rating_q975"] - arena_df.iloc[i]["rating"] + ) + lower_diff = round( + arena_df.iloc[i]["rating"] - arena_df.iloc[i]["rating_q025"] + ) + row.append(f"+{upper_diff}/-{lower_diff}") + # num battles + row.append(round(arena_df.iloc[i]["num_battles"])) + # Organization + row.append( + model_table_df[model_table_df["key"] == model_key][ + "Organization" + ].values[0] + ) + # license + row.append( + model_table_df[model_table_df["key"] == model_key]["License"].values[0] + ) + cutoff_date = model_table_df[model_table_df["key"] == model_key][ + "Knowledge cutoff date" + ].values[0] + if cutoff_date == "-": + row.append("Unknown") + else: + row.append(cutoff_date) + values.append(row) + except Exception as e: + print(f"{model_key} - {e}") + return values + + +key_to_category_name = { + "full": "Overall", + "dedup": "De-duplicate Top Redundant Queries (soon to be default)", + "math": "Math", + "if": "Instruction Following", + "multiturn": "Multi-Turn", + "coding": "Coding", + "hard_6": "Hard Prompts (Overall)", + "hard_english_6": "Hard Prompts (English)", + "long_user": "Longer Query", + "english": "English", + "chinese": "Chinese", + "french": "French", + "german": "German", + 
"spanish": "Spanish", + "russian": "Russian", + "japanese": "Japanese", + "korean": "Korean", + "no_tie": "Exclude Ties", + "no_short": "Exclude Short Query (< 5 tokens)", + "no_refusal": "Exclude Refusal", + "overall_limit_5_user_vote": "overall_limit_5_user_vote", + "full_old": "Overall (Deprecated)", +} +cat_name_to_explanation = { + "Overall": "Overall Questions", + "De-duplicate Top Redundant Queries (soon to be default)": "De-duplicate top redundant queries (top 0.1%). See details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/#note-enhancing-quality-through-de-duplication).", + "Math": "Math", + "Instruction Following": "Instruction Following", + "Multi-Turn": "Multi-Turn Conversation (>= 2 turns)", + "Coding": "Coding: whether conversation contains code snippets", + "Hard Prompts (Overall)": "Hard Prompts (Overall): details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/)", + "Hard Prompts (English)": "Hard Prompts (English), note: the delta is to English Category. details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/)", + "Longer Query": "Longer Query (>= 500 tokens)", + "English": "English Prompts", + "Chinese": "Chinese Prompts", + "French": "French Prompts", + "German": "German Prompts", + "Spanish": "Spanish Prompts", + "Russian": "Russian Prompts", + "Japanese": "Japanese Prompts", + "Korean": "Korean Prompts", + "Exclude Ties": "Exclude Ties and Bothbad", + "Exclude Short Query (< 5 tokens)": "Exclude Short User Query (< 5 tokens)", + "Exclude Refusal": 'Exclude model responses with refusal (e.g., "I cannot answer")', + "overall_limit_5_user_vote": "overall_limit_5_user_vote", + "Overall (Deprecated)": "Overall without De-duplicating Top Redundant Queries (top 0.1%). See details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/#note-enhancing-quality-through-de-duplication).", +} +cat_name_to_baseline = { + "Hard Prompts (English)": "English", +} + + +def update_leaderboard_df(arena_table_vals): + elo_datarame = pd.DataFrame( + arena_table_vals, + columns=[ + "Rank* (UB)", + "Delta", + "Model", + "Arena Elo", + "95% CI", + "Votes", + "Organization", + "License", + "Knowledge Cutoff", + ], + ) + + # goal: color the rows based on the rank with styler + def highlight_max(s): + # all items in S which contain up arrow should be green, down arrow should be red, otherwise black + return [ + ( + "color: green; font-weight: bold" + if "\u2191" in v + else "color: red; font-weight: bold" + if "\u2193" in v + else "" + ) + for v in s + ] + + def highlight_rank_max(s): + return [ + ( + "color: green; font-weight: bold" + if v > 0 + else "color: red; font-weight: bold" + if v < 0 + else "" + ) + for v in s + ] + + return elo_datarame.style.apply(highlight_max, subset=["Rank* (UB)"]).apply( + highlight_rank_max, subset=["Delta"] + ) + + +def build_arena_tab( + elo_results, + model_table_df, + default_md, + vision=False, + show_plot=False, +): + if elo_results is None: + gr.Markdown( + """ ## Coming soon...! 
+ """, + ) + return + + arena_dfs = {} + category_elo_results = {} + last_updated_time = elo_results["full"]["last_updated_datetime"].split(" ")[0] + for k in key_to_category_name.keys(): + if k not in elo_results: + continue + arena_dfs[key_to_category_name[k]] = elo_results[k]["leaderboard_table_df"] + category_elo_results[key_to_category_name[k]] = elo_results[k] + + arena_df = arena_dfs["Overall"] + + def update_leaderboard_and_plots(category): + arena_subset_df = arena_dfs[category] + arena_subset_df = arena_subset_df[arena_subset_df["num_battles"] > 300] + elo_subset_results = category_elo_results[category] + + baseline_category = cat_name_to_baseline.get(category, "Overall") + arena_df = arena_dfs[baseline_category] + arena_values = get_arena_table( + arena_df, + model_table_df, + arena_subset_df=arena_subset_df if category != "Overall" else None, + ) + if category != "Overall": + arena_values = update_leaderboard_df(arena_values) + # arena_values = highlight_top_models(arena_values) + arena_values = gr.Dataframe( + headers=[ + "Rank* (UB)", + "Delta", + "Model", + "Arena Score", + "95% CI", + "Votes", + "Organization", + "License", + "Knowledge Cutoff", + ], + datatype=[ + "str", + "number", + "markdown", + "number", + "str", + "number", + "str", + "str", + "str", + ], + value=arena_values, + elem_id="arena_leaderboard_dataframe", + height=800, + column_widths=[70, 70, 200, 90, 100, 90, 120, 150, 100], + wrap=True, + ) + else: + arena_values = gr.Dataframe( + headers=[ + "Rank* (UB)", + "🤖 Model", + "⭐ Arena Score", + "📊 95% CI", + "🗳️ Votes", + "Organization", + "License", + "Knowledge Cutoff", + ], + datatype=[ + "str", + "markdown", + "number", + "str", + "number", + "str", + "str", + "str", + ], + value=arena_values, + elem_id="arena_leaderboard_dataframe", + height=800, + column_widths=[70, 190, 100, 100, 90, 140, 150, 100], + wrap=True, + ) + + p1 = elo_subset_results["win_fraction_heatmap"] + p2 = elo_subset_results["battle_count_heatmap"] + p3 = elo_subset_results["bootstrap_elo_rating"] + p4 = elo_subset_results["average_win_rate_bar"] + more_stats_md = f"""## More Statistics for Chatbot Arena - {category} + """ + leaderboard_md = make_category_arena_leaderboard_md( + arena_df, arena_subset_df, name=category + ) + return arena_values, p1, p2, p3, p4, more_stats_md, leaderboard_md + + arena_df = arena_dfs["Overall"] + + p1 = category_elo_results["Overall"]["win_fraction_heatmap"] + p2 = category_elo_results["Overall"]["battle_count_heatmap"] + p3 = category_elo_results["Overall"]["bootstrap_elo_rating"] + p4 = category_elo_results["Overall"]["average_win_rate_bar"] + + # arena table + arena_table_vals = get_arena_table(arena_df, model_table_df) + + md = make_arena_leaderboard_md(arena_df, last_updated_time, vision=vision) + gr.Markdown(md, elem_id="leaderboard_markdown") + with gr.Row(): + with gr.Column(scale=2): + category_dropdown = gr.Dropdown( + choices=list(arena_dfs.keys()), + label="Category", + value="Overall", + ) + default_category_details = make_category_arena_leaderboard_md( + arena_df, arena_df, name="Overall" + ) + with gr.Column(scale=4, variant="panel"): + category_deets = gr.Markdown( + default_category_details, elem_id="category_deets" + ) + + arena_vals = pd.DataFrame( + arena_table_vals, + columns=[ + "Rank* (UB)", + "Model", + "Arena Score", + "95% CI", + "Votes", + "Organization", + "License", + "Knowledge Cutoff", + ], + ) + elo_display_df = gr.Dataframe( + headers=[ + "Rank* (UB)", + "🤖 Model", + "⭐ Arena Elo", + "📊 95% CI", + "🗳️ Votes", + 
"Organization", + "License", + "Knowledge Cutoff", + ], + datatype=[ + "str", + "markdown", + "number", + "str", + "number", + "str", + "str", + "str", + ], + # value=highlight_top_models(arena_vals.style), + value=arena_vals.style, + elem_id="arena_leaderboard_dataframe", + height=800, + column_widths=[70, 190, 100, 100, 90, 130, 150, 100], + wrap=True, + ) + + gr.Markdown( + f""" +***Rank (UB)**: model's ranking (upper-bound), defined by one + the number of models that are statistically better than the target model. +Model A is statistically better than model B when A's lower-bound score is greater than B's upper-bound score (in 95% confidence interval). +See Figure 1 below for visualization of the confidence intervals of model scores. + +Note: in each category, we exclude models with fewer than 300 votes as their confidence intervals can be large. +""", + elem_id="leaderboard_markdown", + ) + + if not vision: + # only live update the text tab + leader_component_values[:] = [default_md, p1, p2, p3, p4] + + if show_plot: + more_stats_md = gr.Markdown( + f"""## More Statistics for Chatbot Arena (Overall)""", + elem_id="leaderboard_header_markdown", + ) + with gr.Row(): + with gr.Column(): + gr.Markdown( + "#### Figure 1: Confidence Intervals on Model Strength (via Bootstrapping)", + elem_id="plot-title", + ) + plot_3 = gr.Plot(p3, show_label=False) + with gr.Column(): + gr.Markdown( + "#### Figure 2: Average Win Rate Against All Other Models (Assuming Uniform Sampling and No Ties)", + elem_id="plot-title", + ) + plot_4 = gr.Plot(p4, show_label=False) + with gr.Row(): + with gr.Column(): + gr.Markdown( + "#### Figure 3: Fraction of Model A Wins for All Non-tied A vs. B Battles", + elem_id="plot-title", + ) + plot_1 = gr.Plot(p1, show_label=False, elem_id="plot-container") + with gr.Column(): + gr.Markdown( + "#### Figure 4: Battle Count for Each Combination of Models (without Ties)", + elem_id="plot-title", + ) + plot_2 = gr.Plot(p2, show_label=False) + category_dropdown.change( + update_leaderboard_and_plots, + inputs=[category_dropdown], + outputs=[ + elo_display_df, + plot_1, + plot_2, + plot_3, + plot_4, + more_stats_md, + category_deets, + ], + ) + return [plot_1, plot_2, plot_3, plot_4] + + +def build_full_leaderboard_tab(elo_results, model_table_df, model_to_score): + arena_df = elo_results["full"]["leaderboard_table_df"] + md = make_full_leaderboard_md() + gr.Markdown(md, elem_id="leaderboard_markdown") + full_table_vals = get_full_table(arena_df, model_table_df, model_to_score) + gr.Dataframe( + headers=[ + "Model", + "Arena Elo", + "Arena-Hard-Auto", + "MT-bench", + "MMLU", + "Organization", + "License", + ], + datatype=["markdown", "number", "number", "number", "number", "str", "str"], + value=full_table_vals, + elem_id="full_leaderboard_dataframe", + column_widths=[200, 100, 110, 100, 70, 130, 150], + height=800, + wrap=True, + ) + + +def build_leaderboard_tab( + elo_results_file, + leaderboard_table_file, + arena_hard_leaderboard, + show_plot=False, + mirror=False, +): + if elo_results_file is None: # Do live update + default_md = "Loading ..." 
+ p1 = p2 = p3 = p4 = None + else: + with open(elo_results_file, "rb") as fin: + elo_results = pickle.load(fin) + if "text" in elo_results: + elo_results_text = elo_results["text"] + elo_results_vision = elo_results["vision"] + else: + elo_results_text = elo_results + elo_results_vision = None + + default_md = make_default_md_1(mirror=mirror) + default_md_2 = make_default_md_2(mirror=mirror) + + with gr.Row(): + with gr.Column(scale=4): + md_1 = gr.Markdown(default_md, elem_id="leaderboard_markdown") + with gr.Column(scale=1): + vote_button = gr.Button("Vote!", link="https://chat.lmsys.org") + md2 = gr.Markdown(default_md_2, elem_id="leaderboard_markdown") + if leaderboard_table_file: + data = load_leaderboard_table_csv(leaderboard_table_file) + model_table_df = pd.DataFrame(data) + + with gr.Tabs() as tabs: + with gr.Tab("Arena", id=0): + gr_plots = build_arena_tab( + elo_results_text, + model_table_df, + default_md, + show_plot=show_plot, + ) + with gr.Tab("📣 NEW: Arena (Vision)", id=1): + build_arena_tab( + elo_results_vision, + model_table_df, + default_md, + vision=True, + show_plot=show_plot, + ) + if arena_hard_leaderboard is not None: + with gr.Tab("Arena-Hard-Auto", id=2): + dataFrame = arena_hard_process( + leaderboard_table_file, arena_hard_leaderboard + ) + date = dataFrame["date"][0] + dataFrame = dataFrame.drop( + columns=["rating_q025", "rating_q975", "date"] + ) + dataFrame["CI"] = dataFrame.CI.map(ast.literal_eval) + dataFrame["CI"] = dataFrame.CI.map(lambda x: f"+{x[1]}/-{x[0]}") + dataFrame = dataFrame.rename( + columns={ + "model": "Model", + "score": "Win-rate", + "CI": "95% CI", + "avg_tokens": "Average Tokens", + } + ) + model_to_score = {} + for i in range(len(dataFrame)): + model_to_score[dataFrame.loc[i, "Model"]] = dataFrame.loc[ + i, "Win-rate" + ] + md = arena_hard_title(date) + gr.Markdown(md, elem_id="leaderboard_markdown") + gr.DataFrame( + dataFrame, + datatype=[ + "markdown" if col == "Model" else "str" + for col in dataFrame.columns + ], + elem_id="arena_hard_leaderboard", + height=800, + wrap=True, + column_widths=[70, 190, 80, 80, 90, 150], + ) + + with gr.Tab("Full Leaderboard", id=3): + build_full_leaderboard_tab( + elo_results_text, model_table_df, model_to_score + ) + + if not show_plot: + gr.Markdown( + """ ## Visit our [HF space](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard) for more analysis! + If you want to see more models, please help us [add them](https://github.com/lm-sys/FastChat/blob/main/docs/arena.md#how-to-add-a-new-model). + """, + elem_id="leaderboard_markdown", + ) + else: + pass + + from fastchat.serve.gradio_web_server import acknowledgment_md + + with gr.Accordion( + "Citation", + open=True, + ): + citation_md = """ + ### Citation + Please cite the following paper if you find our leaderboard or dataset helpful. + ``` + @misc{chiang2024chatbot, + title={Chatbot Arena: An Open Platform for Evaluating LLMs by Human Preference}, + author={Wei-Lin Chiang and Lianmin Zheng and Ying Sheng and Anastasios Nikolas Angelopoulos and Tianle Li and Dacheng Li and Hao Zhang and Banghua Zhu and Michael Jordan and Joseph E. 
Gonzalez and Ion Stoica}, + year={2024}, + eprint={2403.04132}, + archivePrefix={arXiv}, + primaryClass={cs.AI} + } + """ + gr.Markdown(citation_md, elem_id="leaderboard_markdown") + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + + return [md_1] + gr_plots + + +def build_demo(elo_results_file, leaderboard_table_file, arena_hard_leaderboard): + from fastchat.serve.gradio_web_server import block_css + + text_size = gr.themes.sizes.text_lg + # load theme from theme.json + theme = gr.themes.Default.load("theme.json") + # set text size to large + theme.text_size = text_size + theme.set( + button_large_text_size="40px", + button_small_text_size="40px", + button_large_text_weight="1000", + button_small_text_weight="1000", + button_shadow="*shadow_drop_lg", + button_shadow_hover="*shadow_drop_lg", + checkbox_label_shadow="*shadow_drop_lg", + button_shadow_active="*shadow_inset", + button_secondary_background_fill="*primary_300", + button_secondary_background_fill_dark="*primary_700", + button_secondary_background_fill_hover="*primary_200", + button_secondary_background_fill_hover_dark="*primary_500", + button_secondary_text_color="*primary_800", + button_secondary_text_color_dark="white", + ) + + with gr.Blocks( + title="Chatbot Arena Leaderboard", + # theme=gr.themes.Default(text_size=text_size), + theme=theme, + css=block_css, + ) as demo: + with gr.Tabs() as tabs: + with gr.Tab("Leaderboard", id=0): + leader_components = build_leaderboard_tab( + elo_results_file, + leaderboard_table_file, + arena_hard_leaderboard, + show_plot=True, + mirror=False, + ) + + with gr.Tab("Basic Stats", id=1): + basic_components = build_basic_stats_tab() + + url_params = gr.JSON(visible=False) + demo.load( + load_demo, + [url_params], + basic_components + leader_components, + js=get_window_url_params_js, + ) + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument("--share", action="store_true") + parser.add_argument("--concurrency-count", type=int, default=10) + parser.add_argument("--update-interval", type=int, default=300) + parser.add_argument("--max-num-files", type=int) + parser.add_argument("--elo-results-file", type=str) + parser.add_argument("--leaderboard-table-file", type=str) + parser.add_argument("--ban-ip-file", type=str) + parser.add_argument("--exclude-model-names", type=str, nargs="+") + parser.add_argument("--password", type=str, default=None, nargs="+") + parser.add_argument("--arena-hard-leaderboard", type=str) + args = parser.parse_args() + + logger = build_logger("monitor", "monitor.log") + logger.info(f"args: {args}") + + if args.elo_results_file is None: # Do live update + update_thread = threading.Thread( + target=update_worker, + args=( + args.max_num_files, + args.update_interval, + args.elo_results_file, + args.ban_ip_file, + args.exclude_model_names, + ), + ) + update_thread.start() + + demo = build_demo( + args.elo_results_file, args.leaderboard_table_file, args.arena_hard_leaderboard + ) + demo.queue( + default_concurrency_limit=args.concurrency_count, + status_update_rate=10, + api_open=False, + ).launch( + server_name=args.host, + server_port=args.port, + share=args.share, + max_threads=200, + auth=(args.password[0], args.password[1]) if args.password else None, + ) diff --git a/fastchat/serve/monitor/summarize_cluster.py b/fastchat/serve/monitor/summarize_cluster.py new file mode 100644 index 
0000000000000000000000000000000000000000..b461a68b2bfeeaf1a660103b491edf7f0b255a21 --- /dev/null +++ b/fastchat/serve/monitor/summarize_cluster.py @@ -0,0 +1,85 @@ +""" +Usage: +python3 summarize_cluster.py --in results_c20_kmeans_cluster.pkl --model gpt-4 --num-prompts 100 +python3 summarize_cluster.py --in results_c20_kmeans_cluster.pkl --model azure-gpt-4-32k --num-prompts 200 +""" +import argparse +import pickle + +import pandas as pd + +from fastchat.llm_judge.common import ( + chat_completion_openai, + chat_completion_openai_azure, + chat_completion_anthropic, +) +from fastchat.conversation import get_conv_template + + +def truncate_string(s, l): + half = int(l // 2) + return s[:half] + s[-half:] if len(s) > l else s + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input-file", type=str, required=True) + parser.add_argument("--model", type=str, default="gpt-3.5-turbo") + parser.add_argument("--num-prompts", type=int, default=100) + args = parser.parse_args() + + model = args.model + + cluster_infos = pickle.load(open(args.input_file, "rb")) + num_total_prompts = sum([x[0] for x in cluster_infos]) + + topics = [] + percentages = [] + for i, info in enumerate(cluster_infos): + num_samples, topk_prompts, random_prompts = info + percentage = num_samples / num_total_prompts + print( + f"cluster {i}, #prompts {num_samples}, percentage: {percentage * 100:.2f}%" + ) + instruct = "Given a list of user messages, use less than 8 words to summarize a central topic for all messages in English. Your output should only include a single line. Try to be specific." + split = int(args.num_prompts * 0.8) + prompt = "\n".join( + [truncate_string(x, l=200) for x in topk_prompts[:split]] + + [ + truncate_string(x, l=200) + for x in random_prompts[: args.num_prompts - split] + ] + ) + prompt = "BEGIN OF THE MESSAGE LIST\n" + prompt + "\nEND OF THE MESSAGE LIST." + + if "azure-" in model: + template_name = "chatgpt" + completion_func = chat_completion_openai_azure + elif "gpt" in model: + template_name = "chatgpt" + completion_func = chat_completion_openai + elif "claude" in model: + template_name = "claude" + completion_func = chat_completion_anthropic + + conv = get_conv_template(template_name) + conv.set_system_message(instruct) + conv.append_message(conv.roles[0], prompt) + conv.append_message(conv.roles[1], None) + + topic = completion_func(model, conv, temperature=0, max_tokens=256) + print(topic) + + topics.append(topic) + percentages.append(round(percentage, 6)) + + print() + print(f"topics: {topics}") + print(f"percentages: {percentages}") + + # save the informations + df = pd.DataFrame() + df["topic"] = topics + df["percentage"] = percentages + + df.to_json(f"cluster_summary_{len(df)}.jsonl", lines=True, orient="records") diff --git a/fastchat/serve/monitor/tag_openai_moderation.py b/fastchat/serve/monitor/tag_openai_moderation.py new file mode 100644 index 0000000000000000000000000000000000000000..b80703388b2a47bf372a09bbed81d7bede2bd412 --- /dev/null +++ b/fastchat/serve/monitor/tag_openai_moderation.py @@ -0,0 +1,63 @@ +""" +Add OpenAI moderation API results to all conversations. 
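+
+Usage (illustrative; the input file name is an example, the flags are the argparse options defined below):
+python3 tag_openai_moderation.py --input clean_battle_conv.json --parallel 8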
+""" +import argparse +from concurrent.futures import ThreadPoolExecutor +import json +import os +import time + +import openai +import requests +from tqdm import tqdm + + +API_MAX_RETRY = 16 +API_RETRY_SLEEP = 10 +API_ERROR_OUTPUT = "$ERROR$" + + +def tag_moderation(text): + result = API_ERROR_OUTPUT + for _ in range(API_MAX_RETRY): + try: + result = openai.Moderation.create(input=text)["results"][0] + break + except openai.error.OpenAIError as e: + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + + return result + + +def tag_openai_moderation(x): + conv = x["conversation_a"] + user_prompts = "\n".join([x["content"] for x in conv if x["role"] == "user"]) + result = tag_moderation(user_prompts) + x["openai_moderation"] = result + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str, required=True) + parser.add_argument( + "--parallel", type=int, default=1, help="The number of concurrent API calls." + ) + parser.add_argument("--first-n", type=int) + args = parser.parse_args() + + battles = json.load(open(args.input)) + + if args.first_n: + battles = battles[: args.first_n] + + with ThreadPoolExecutor(args.parallel) as executor: + for line in tqdm( + executor.map(tag_openai_moderation, battles), total=len(battles) + ): + pass + + output = args.input.replace(".json", "_tagged.json") + with open(output, "w") as fout: + json.dump(battles, fout, indent=2, ensure_ascii=False) + print(f"Write cleaned data to {output}") diff --git a/fastchat/serve/monitor/topic_clustering.py b/fastchat/serve/monitor/topic_clustering.py new file mode 100644 index 0000000000000000000000000000000000000000..3d58e56bf3a749fe13480ea0eb7f7d2d412accc7 --- /dev/null +++ b/fastchat/serve/monitor/topic_clustering.py @@ -0,0 +1,292 @@ +""" + +Usage: +python3 topic_clustering.py --in arena.json --english-only --min-length 32 +python3 topic_clustering.py --in clean_conv_20230809_100k.json --english-only --min-length 32 --max-length 1536 +""" +import argparse +import json +import pickle +import string +import time + +import numpy as np +from sentence_transformers import SentenceTransformer +from sentence_transformers.util import cos_sim +from sklearn.cluster import KMeans, AgglomerativeClustering +import torch +from tqdm import tqdm +from openai import OpenAI + +from fastchat.utils import detect_language + + +def remove_punctuation(input_string): + # Make a translator object to remove all punctuation + translator = str.maketrans("", "", string.punctuation) + + # Use the translator object to remove the punctuation + no_punct = input_string.translate(translator) + return no_punct + + +def read_texts(input_file, min_length, max_length, english_only): + visited = set() + texts = [] + + lines = json.load(open(input_file, "r")) + + for l in tqdm(lines): + if "text" in l: + line_texts = [l["text"]] + elif "conversation_a" in l: + line_texts = [ + x["content"] for x in l["conversation_a"] if x["role"] == "user" + ] + elif "conversation" in l: + line_texts = [ + x["content"] for x in l["conversation"] if x["role"] == "user" + ] + elif "turns" in l: + line_texts = l["turns"] + + for text in line_texts: + text = text.strip() + + # Filter language + if english_only: + lang = detect_language(text) + if lang != "English": + continue + + # Filter short or long prompts + if min_length: + if len(text) < min_length: + continue + + if max_length: + if len(text) > max_length: + continue + + # De-duplication + words = sorted([x.lower() for x in remove_punctuation(text).split(" ")]) + words = 
"".join(words) + if words in visited: + continue + + visited.add(words) + texts.append(text) + return np.array(texts) + + +def get_embeddings(texts, model_name, batch_size): + if model_name == "text-embedding-ada-002": + client = OpenAI() + texts = texts.tolist() + + embeddings = [] + for i in tqdm(range(0, len(texts), batch_size)): + text = texts[i : i + batch_size] + responses = client.embeddings.create(input=text, model=model_name).data + embeddings.extend([data.embedding for data in responses]) + embeddings = torch.tensor(embeddings) + else: + model = SentenceTransformer(model_name) + embeddings = model.encode( + texts, + batch_size=batch_size, + show_progress_bar=True, + device="cuda", + convert_to_tensor=True, + ) + + embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) + return embeddings.cpu() + + +def run_k_means(embeddings, num_clusters): + np.random.seed(42) + clustering_model = KMeans(n_clusters=num_clusters, n_init="auto") + clustering_model.fit(embeddings.numpy()) + centers = torch.from_numpy(clustering_model.cluster_centers_) + labels = torch.from_numpy(clustering_model.labels_) + + # Sort labels + classes, counts = np.unique(labels, return_counts=True) + indices = np.argsort(counts)[::-1] + classes = [classes[i] for i in indices] + new_labels = torch.empty_like(labels) + new_centers = torch.empty_like(centers) + for i, c in enumerate(classes): + new_labels[labels == c] = i + new_centers[i] = centers[c] + return new_centers, new_labels + + +def run_agg_cluster(embeddings, num_clusters): + np.random.seed(42) + clustering_model = AgglomerativeClustering(n_clusters=num_clusters) + clustering_model.fit(embeddings) + labels = torch.from_numpy(clustering_model.labels_) + + # Sort labels + classes, counts = np.unique(labels, return_counts=True) + indices = np.argsort(counts)[::-1] + classes = [classes[i] for i in indices] + new_labels = torch.empty_like(labels) + for i, c in enumerate(classes): + new_labels[labels == c] = i + + # Compute centers + centers = [] + for i in range(len(classes)): + centers.append(embeddings[new_labels == i].mean(axis=0, keepdim=True)) + centers = torch.cat(centers) + return centers, new_labels + + +def run_hdbscan_cluster(embeddings): + import hdbscan + + np.random.seed(42) + clusterer = hdbscan.HDBSCAN(min_cluster_size=10) + labels = torch.from_numpy(clusterer.fit_predict(embeddings)) + + # Sort labels + classes, counts = np.unique(labels, return_counts=True) + indices = np.argsort(counts)[::-1] + classes = [classes[i] for i in indices] + new_labels = torch.empty_like(labels) + for i, c in enumerate(classes): + new_labels[labels == c] = i + + # Compute centers + centers = [] + for i in range(len(classes)): + centers.append(embeddings[new_labels == i].mean(axis=0, keepdim=True)) + centers = torch.cat(centers) + return centers, new_labels + + +def get_topk_indices(centers, labels, embeddings, topk): + indices = [] + arange = torch.arange(len(labels)) + counts = torch.unique(labels, return_counts=True)[1] + topk = min(topk, counts.min().item()) + for i in range(len(centers)): + tmp_indices = labels == i + tmp_arange = arange[tmp_indices] + tmp_embeddings = embeddings[tmp_indices] + + scores = cos_sim(centers[i].unsqueeze(0), tmp_embeddings)[0] + sorted_indices = torch.flip(torch.argsort(scores), dims=[0]) + indices.append(tmp_arange[sorted_indices[:topk]].unsqueeze(0)) + return torch.cat(indices) + + +def print_topk(texts, labels, topk_indices, show_cut_off): + ret = "" + for k in range(len(topk_indices)): + num_samples = torch.sum(labels == 
k).item() + + ret += "=" * 20 + f" cluster {k}, #samples: {num_samples} " + "=" * 20 + "\n" + for idx in topk_indices[k]: + ret += "PROMPT: " + texts[idx][:show_cut_off] + "\n" + ret += "=" * 40 + "\n\n" + + return ret + + +def get_cluster_info(texts, labels, topk_indices): + np.random.seed(42) + + cluster_info = [] + for k in range(len(topk_indices)): + num_samples = torch.sum(labels == k).item() + topk_prompts = [] + for idx in topk_indices[k]: + topk_prompts.append(texts[idx]) + random_prompts = [] + for idx in range(len(topk_indices)): + random_prompts.append(np.random.choice(texts)) + cluster_info.append((num_samples, topk_prompts, random_prompts)) + + return cluster_info + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input-file", type=str, required=True) + parser.add_argument("--model", type=str, default="all-mpnet-base-v2") + # default="all-MiniLM-L12-v2") + # default="multi-qa-distilbert-cos-v1") + parser.add_argument("--batch-size", type=int, default=256) + parser.add_argument("--min-length", type=int) + parser.add_argument("--max-length", type=int) + parser.add_argument("--english-only", action="store_true") + parser.add_argument("--num-clusters", type=int, default=20) + parser.add_argument( + "--cluster-alg", + type=str, + choices=["kmeans", "aggcls", "HDBSCAN"], + default="kmeans", + ) + parser.add_argument("--show-top-k", type=int, default=200) + parser.add_argument("--show-cut-off", type=int, default=512) + parser.add_argument("--save-embeddings", action="store_true") + parser.add_argument("--embeddings-file", type=str, default=None) + args = parser.parse_args() + + num_clusters = args.num_clusters + show_top_k = args.show_top_k + show_cut_off = args.show_cut_off + + texts = read_texts( + args.input_file, args.min_length, args.max_length, args.english_only + ) + print(f"#text: {len(texts)}") + + if args.embeddings_file is None: + embeddings = get_embeddings(texts, args.model, args.batch_size) + if args.save_embeddings: + # allow saving embedding to save time and money + torch.save(embeddings, "embeddings.pt") + else: + embeddings = torch.load(args.embeddings_file) + print(f"embeddings shape: {embeddings.shape}") + + if args.cluster_alg == "kmeans": + centers, labels = run_k_means(embeddings, num_clusters) + elif args.cluster_alg == "aggcls": + centers, labels = run_agg_cluster(embeddings, num_clusters) + elif args.cluster_alg == "HDBSCAN": + centers, labels = run_hdbscan_cluster(embeddings) + else: + raise ValueError(f"Invalid clustering algorithm: {args.cluster_alg}") + + topk_indices = get_topk_indices(centers, labels, embeddings, args.show_top_k) + topk_str = print_topk(texts, labels, topk_indices, args.show_cut_off) + num_clusters = len(centers) + + # Dump results + filename_prefix = f"results_c{num_clusters}_{args.cluster_alg}" + print(topk_str) + with open(filename_prefix + "_topk.txt", "w") as fout: + fout.write(topk_str) + + with open(filename_prefix + "_all.jsonl", "w") as fout: + for i in range(len(centers)): + tmp_indices = labels == i + tmp_embeddings = embeddings[tmp_indices] + tmp_texts = texts[tmp_indices] + + scores = cos_sim(centers[i].unsqueeze(0), tmp_embeddings)[0] + sorted_indices = torch.flip(torch.argsort(scores), dims=[0]) + + for text, score in zip(tmp_texts[sorted_indices], scores[sorted_indices]): + obj = {"cluster": i, "text": text, "sim": score.item()} + fout.write(json.dumps(obj, ensure_ascii=False) + "\n") + + cluster_info = get_cluster_info(texts, labels, topk_indices) + with open(filename_prefix + 
"_cluster.pkl", "wb") as fout: + pickle.dump(cluster_info, fout) diff --git a/fastchat/serve/monitor/vote_time_stats/README.md b/fastchat/serve/monitor/vote_time_stats/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7404839b98db34ded6a5fd9c8749fef27766b8a8 --- /dev/null +++ b/fastchat/serve/monitor/vote_time_stats/README.md @@ -0,0 +1,5 @@ +# Instructions + +First run `analyze_data.py` to collect metadata of all votes. + +Then run `plot.py` to get the plot. You need to edit these files for proper input or output filename diff --git a/fastchat/serve/monitor/vote_time_stats/analyze_data.py b/fastchat/serve/monitor/vote_time_stats/analyze_data.py new file mode 100644 index 0000000000000000000000000000000000000000..4bdd18694f55fb2e29aed357dd2637e0477966e3 --- /dev/null +++ b/fastchat/serve/monitor/vote_time_stats/analyze_data.py @@ -0,0 +1,120 @@ +import datetime +import glob +import json +from collections import deque +import tqdm + + +def _serialize_json(data): + # Serialize JSON with sorted keys and no whitespace + return json.dumps(data, sort_keys=True, separators=(",", ":")).encode("utf-8") + + +types = { + "share", + "chat", + "flag", + "bothbad_vote", + "downvote", + "leftvote", + "rightvote", + "upvote", + "tievote", +} + +chat_dict = {} +cache_queue = deque() + + +def process_record(r): + ip = r.pop("ip", None) + tstamp = r.pop("tstamp") + mtype = r.pop("type") + start = r.pop("start", None) + finish = r.pop("finish", None) + + # gabagge collect to save memory + while len(cache_queue) > 100000: + outdated = cache_queue.popleft() + poped_item = chat_dict.pop(outdated["key"], None) + if poped_item is None: + # TODO: this sometimes happens, need to investigate what happens. in theory the chat dict should be synced with the queue, unless there are duplicated items + print("Error: Key to GC does not exist.") + + assert mtype in types + if mtype == "chat": + key = _serialize_json(r["state"]) + # TODO: add the string length of the last reply for analyzing voting time per character. 
+ chat_dict[key] = { + "timestamp": tstamp, + "start": start, + "finish": finish, + "conv_id": r["state"]["conv_id"], + } + cache_queue.append({"key": key, "timestamp": tstamp}) + elif mtype in ("leftvote", "rightvote", "bothbad_vote", "tievote"): + left_key = _serialize_json(r["states"][0]) + right_key = _serialize_json(r["states"][1]) + if left_key not in chat_dict: + # TODO: this sometimes happens, it means we have the vote but we cannot find previous chat, need to investigate what happens + print( + f'WARNING: Cannot find vote context for conversation {r["states"][0]["conv_id"]}' + ) + return + if right_key not in chat_dict: + print( + f'WARNING: Cannot find vote context for conversation {r["states"][1]["conv_id"]}' + ) + return + vote_time_data = { + "timestamp": tstamp, + "type": mtype, + "left": chat_dict[left_key], + "right": chat_dict[right_key], + "ip": ip, + } + return vote_time_data + + return None + + +def process_file(infile: str, outfile: str): + with open(infile) as f: + records = [] + for l in f.readlines(): + l = l.strip() + if l: + try: + r = json.loads(l) + if r.get("tstamp") is not None: + records.append(r) + except Exception: + pass + # sort the record in case there are out-of-order records + records.sort(key=lambda x: x["tstamp"]) + + with open(outfile, "a") as outfile: + for r in records: + try: + output = process_record(r) + if output is not None: + outfile.write(json.dumps(output) + "\n") + except Exception as e: + import traceback + + print("Error:", e) + traceback.print_exc() + + +today = datetime.datetime.today().isoformat().split("T", 1)[0] +# sort it to make sure the date is continuous for each server +filelist = sorted(glob.glob("/mnt/disks/data/fastchat_logs/server*/202*-*-*-conv.json")) +filelist = [ + f for f in filelist if today not in f +] # skip today because date could be partial + +# TODO: change this to select different range of data +filelist = [f for f in filelist if "2024-03-" in f] + +for f in tqdm.tqdm(filelist): + process_file(f, "output.jsonl") diff --git a/fastchat/serve/monitor/vote_time_stats/plot.py b/fastchat/serve/monitor/vote_time_stats/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..ba6fd5e37a4524dda3c7ae7a05313fefbe030a86 --- /dev/null +++ b/fastchat/serve/monitor/vote_time_stats/plot.py @@ -0,0 +1,66 @@ +import json +import matplotlib.pyplot as plt +import seaborn as sns +import numpy as np + + +infile = "output.jsonl" +date = "2024-03" # used in the plot + +durations = [] + +with open(infile) as f: + for line in f: + data = json.loads(line) + l = data["left"]["finish"] + r = data["right"]["finish"] + v = data["timestamp"] + durations.append(v - max(l, r)) + +print( + f"Avg: {np.mean(durations)}, Median: {np.median(durations)}, Max: {np.max(durations)}" +) + +# Define the new cutoff and number of bins +cutoff = 200.0 # New cutoff value +num_bins_inside_cutoff = 20 # Number of bins from 0 to cutoff + +for i, n in enumerate(durations): + if n > cutoff: + durations[i] = cutoff + 0.5 * cutoff / num_bins_inside_cutoff + +# Create bin edges from 0 to cutoff, with the specified number of bins +bin_edges = np.linspace(0, cutoff, num_bins_inside_cutoff + 1) + +# Adjusting the overflow bin to end at 110 +overflow_cap = ( + cutoff + cutoff / num_bins_inside_cutoff +) # Adjust as needed based on distribution +bin_edges = np.append(bin_edges, overflow_cap) + +# Create the plot with custom bins +sns.histplot( + durations, bins=bin_edges, kde=False +) # Turn off KDE for clearer bar visibility +plt.title(f'Distribution 
of "time to vote" {date}') +plt.xlabel("Duration (seconds)") +plt.ylabel("Frequency") + +# Highlight the overflow bin +plt.axvline(x=cutoff, color="red", linestyle="--") +plt.text( + cutoff + 1, plt.ylim()[1] * 0.9, "Overflow", color="red", ha="left" +) # Adjust text alignment + +# Customizing x-axis labels to hide the "110" +ax = plt.gca() # Get current axis +labels = [item.get_text() for item in ax.get_xticklabels()] +if "110" in labels: + labels[labels.index("110")] = "" # Replace "110" with an empty string +ax.set_xticklabels(labels) + +# Ensure nothing is cut off in the plot +plt.tight_layout() + +# Save the plot to a file with high resolution +plt.savefig(f"duration_distribution_time_to_vote_{date}.png", dpi=300) diff --git a/fastchat/serve/multi_model_worker.py b/fastchat/serve/multi_model_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..5e6266fe0449b59529825c7e8982cccf7d871e0e --- /dev/null +++ b/fastchat/serve/multi_model_worker.py @@ -0,0 +1,300 @@ +""" +A multi-model worker that contains multiple sub-works one for each model. This +supports running a list of models on the same machine so that they can +(potentially) share the same background weights. + +Each model can have one or more model names. + +This multi-model worker assumes the models shares some underlying weights and +thus reports the combined queue lengths for health checks. + +We recommend using this with multiple Peft models (with `peft` in the name) +where all Peft models are trained on the exact same base model. +""" +import argparse +import asyncio +import dataclasses +import logging +import json +import os +import time +from typing import List, Union +import threading +import uuid + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import requests + +try: + from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LlamaTokenizer, + AutoModel, + ) +except ImportError: + from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LLaMATokenizer, + AutoModel, + ) +import torch +import torch.nn.functional as F +import uvicorn + +from fastchat.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG +from fastchat.model.model_adapter import ( + load_model, + add_model_args, + get_conversation_template, +) +from fastchat.model.model_chatglm import generate_stream_chatglm +from fastchat.model.model_falcon import generate_stream_falcon +from fastchat.model.model_codet5p import generate_stream_codet5p +from fastchat.modules.gptq import GptqConfig +from fastchat.modules.exllama import ExllamaConfig +from fastchat.modules.xfastertransformer import XftConfig +from fastchat.serve.inference import generate_stream +from fastchat.serve.model_worker import ModelWorker, worker_id, logger +from fastchat.utils import build_logger, pretty_print_semaphore, get_context_length + + +# We store both the underlying workers and a mapping from their model names to +# the worker instance. This makes it easy to fetch the appropriate worker for +# each API call. +workers = [] +worker_map = {} +app = FastAPI() + + +def release_worker_semaphore(): + workers[0].semaphore.release() + + +def acquire_worker_semaphore(): + if workers[0].semaphore is None: + # Share the same semaphore for all workers because + # all workers share the same GPU. 
+ semaphore = asyncio.Semaphore(workers[0].limit_worker_concurrency) + for w in workers: + w.semaphore = semaphore + return workers[0].semaphore.acquire() + + +def create_background_tasks(): + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + return background_tasks + + +# Note: for all the calls below, we make a hard assumption that the caller +# includes the model name in the payload, otherwise we can't figure out which +# underlying sub-worker to call. + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + worker = worker_map[params["model"]] + generator = worker.generate_stream_gate(params) + background_tasks = create_background_tasks() + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + worker = worker_map[params["model"]] + output = worker.generate_gate(params) + release_worker_semaphore() + return JSONResponse(output) + + +@app.post("/worker_get_embeddings") +async def api_get_embeddings(request: Request): + params = await request.json() + await acquire_worker_semaphore() + worker = worker_map[params["model"]] + embedding = worker.get_embeddings(params) + background_tasks = create_background_tasks() + return JSONResponse(content=embedding, background=background_tasks) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return { + "model_names": [m for w in workers for m in w.model_names], + "speed": 1, + "queue_length": sum([w.get_queue_length() for w in workers]), + } + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return {"context_length": worker.context_len} + + +def create_multi_model_worker(): + # Note: Ensure we resolve arg conflicts. We let `add_model_args` add MOST + # of the model args but we'll override one to have an append action that + # supports multiple values. + parser = argparse.ArgumentParser(conflict_handler="resolve") + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + add_model_args(parser) + # Override the model path to be repeated and align it with model names. + parser.add_argument( + "--model-path", + type=str, + default=[], + action="append", + help="One or more paths to model weights to load. This can be a local folder or a Hugging Face repo ID.", + ) + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + action="append", + help="One or more model names. Values must be aligned with `--model-path` values.", + ) + parser.add_argument( + "--conv-template", + type=str, + default=None, + action="append", + help="Conversation prompt template. 
Values must be aligned with `--model-path` values. If only one value is provided, it will be repeated for all models.", + ) + parser.add_argument("--limit-worker-concurrency", type=int, default=5) + parser.add_argument("--stream-interval", type=int, default=2) + parser.add_argument("--no-register", action="store_true") + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + if args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + + gptq_config = GptqConfig( + ckpt=args.gptq_ckpt or args.model_path, + wbits=args.gptq_wbits, + groupsize=args.gptq_groupsize, + act_order=args.gptq_act_order, + ) + if args.enable_exllama: + exllama_config = ExllamaConfig( + max_seq_len=args.exllama_max_seq_len, + gpu_split=args.exllama_gpu_split, + cache_8bit=args.exllama_cache_8bit, + ) + else: + exllama_config = None + if args.enable_xft: + xft_config = XftConfig( + max_seq_len=args.xft_max_seq_len, + data_type=args.xft_dtype, + ) + if args.device != "cpu": + print("xFasterTransformer now is only support CPUs. Reset device to CPU") + args.device = "cpu" + else: + xft_config = None + + if args.model_names is None: + args.model_names = [[x.split("/")[-1]] for x in args.model_path] + + if args.conv_template is None: + args.conv_template = [None] * len(args.model_path) + elif len(args.conv_template) == 1: # Repeat the same template + args.conv_template = args.conv_template * len(args.model_path) + + # Launch all workers + workers = [] + for conv_template, model_path, model_names in zip( + args.conv_template, args.model_path, args.model_names + ): + w = ModelWorker( + args.controller_address, + args.worker_address, + worker_id, + model_path, + model_names, + args.limit_worker_concurrency, + args.no_register, + device=args.device, + num_gpus=args.num_gpus, + max_gpu_memory=args.max_gpu_memory, + load_8bit=args.load_8bit, + cpu_offloading=args.cpu_offloading, + gptq_config=gptq_config, + exllama_config=exllama_config, + xft_config=xft_config, + stream_interval=args.stream_interval, + conv_template=conv_template, + ) + workers.append(w) + for model_name in model_names: + worker_map[model_name] = w + + # Register all models + url = args.controller_address + "/register_worker" + data = { + "worker_name": workers[0].worker_addr, + "check_heart_beat": not args.no_register, + "worker_status": { + "model_names": [m for w in workers for m in w.model_names], + "speed": 1, + "queue_length": sum([w.get_queue_length() for w in workers]), + }, + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + return args, workers + + +if __name__ == "__main__": + args, workers = create_multi_model_worker() + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/fastchat/serve/openai_api_server.py b/fastchat/serve/openai_api_server.py new file mode 100644 index 0000000000000000000000000000000000000000..a6ffee96bf4f9ea39f49e9a309cab92a651fdfb9 --- /dev/null +++ b/fastchat/serve/openai_api_server.py @@ -0,0 +1,939 @@ +"""A server that provides OpenAI-compatible RESTful APIs. 
It supports: + +- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat) +- Completions. (Reference: https://platform.openai.com/docs/api-reference/completions) +- Embeddings. (Reference: https://platform.openai.com/docs/api-reference/embeddings) + +Usage: +python3 -m fastchat.serve.openai_api_server +""" +import asyncio +import argparse +import json +import os +from typing import Generator, Optional, Union, Dict, List, Any + +import aiohttp +import fastapi +from fastapi import Depends, HTTPException +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse, JSONResponse +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer +import httpx + +from pydantic_settings import BaseSettings +import shortuuid +import tiktoken +import uvicorn + +from fastchat.constants import ( + WORKER_API_TIMEOUT, + WORKER_API_EMBEDDING_BATCH_SIZE, + ErrorCode, +) +from fastchat.conversation import Conversation, SeparatorStyle +from fastchat.protocol.openai_api_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatMessage, + ChatCompletionResponseChoice, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + DeltaMessage, + CompletionResponseStreamChoice, + CompletionStreamResponse, + EmbeddingsRequest, + EmbeddingsResponse, + ErrorResponse, + LogProbs, + ModelCard, + ModelList, + ModelPermission, + UsageInfo, +) +from fastchat.protocol.api_protocol import ( + APIChatCompletionRequest, + APITokenCheckRequest, + APITokenCheckResponse, + APITokenCheckResponseItem, +) +from fastchat.utils import build_logger + +logger = build_logger("openai_api_server", "openai_api_server.log") + +conv_template_map = {} + +fetch_timeout = aiohttp.ClientTimeout(total=3 * 3600) + + +async def fetch_remote(url, pload=None, name=None): + async with aiohttp.ClientSession(timeout=fetch_timeout) as session: + async with session.post(url, json=pload) as response: + chunks = [] + if response.status != 200: + ret = { + "text": f"{response.reason}", + "error_code": ErrorCode.INTERNAL_ERROR, + } + return json.dumps(ret) + + async for chunk, _ in response.content.iter_chunks(): + chunks.append(chunk) + output = b"".join(chunks) + + if name is not None: + res = json.loads(output) + if name != "": + res = res[name] + return res + + return output + + +class AppSettings(BaseSettings): + # The address of the model controller. 
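+ # Because this is a pydantic BaseSettings, both fields can also be supplied via
+ # environment variables (e.g. CONTROLLER_ADDRESS), although they are normally
+ # overridden from the CLI flags in create_openai_api_server() below.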
+ controller_address: str = "http://localhost:21001" + api_keys: Optional[List[str]] = None + + +app_settings = AppSettings() +app = fastapi.FastAPI() +headers = {"User-Agent": "FastChat API Server"} +get_bearer_token = HTTPBearer(auto_error=False) + + +async def check_api_key( + auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), +) -> str: + if app_settings.api_keys: + if auth is None or (token := auth.credentials) not in app_settings.api_keys: + raise HTTPException( + status_code=401, + detail={ + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + }, + ) + return token + else: + # api_keys not set; allow all + return None + + +def create_error_response(code: int, message: str) -> JSONResponse: + return JSONResponse( + ErrorResponse(message=message, code=code).model_dump(), status_code=400 + ) + + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request, exc): + return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc)) + + +async def check_model(request) -> Optional[JSONResponse]: + controller_address = app_settings.controller_address + ret = None + + models = await fetch_remote(controller_address + "/list_models", None, "models") + if request.model not in models: + ret = create_error_response( + ErrorCode.INVALID_MODEL, + f"Only {'&&'.join(models)} allowed now, your model {request.model}", + ) + return ret + + +async def check_length(request, prompt, max_tokens, worker_addr): + if ( + not isinstance(max_tokens, int) or max_tokens <= 0 + ): # model worker not support max_tokens=None + max_tokens = 1024 * 1024 + + context_len = await fetch_remote( + worker_addr + "/model_details", {"model": request.model}, "context_length" + ) + token_num = await fetch_remote( + worker_addr + "/count_token", + {"model": request.model, "prompt": prompt}, + "count", + ) + length = min(max_tokens, context_len - token_num) + + if length <= 0: + return None, create_error_response( + ErrorCode.CONTEXT_OVERFLOW, + f"This model's maximum context length is {context_len} tokens. However, your messages resulted in {token_num} tokens. 
Please reduce the length of the messages.", + ) + + return length, None + + +def check_requests(request) -> Optional[JSONResponse]: + # Check all params + if request.max_tokens is not None and request.max_tokens <= 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'", + ) + if request.n is not None and request.n <= 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.n} is less than the minimum of 1 - 'n'", + ) + if request.temperature is not None and request.temperature < 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.temperature} is less than the minimum of 0 - 'temperature'", + ) + if request.temperature is not None and request.temperature > 2: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.temperature} is greater than the maximum of 2 - 'temperature'", + ) + if request.top_p is not None and request.top_p < 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_p} is less than the minimum of 0 - 'top_p'", + ) + if request.top_p is not None and request.top_p > 1: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_p} is greater than the maximum of 1 - 'top_p'", + ) + if request.top_k is not None and (request.top_k > -1 and request.top_k < 1): + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_k} is out of Range. Either set top_k to -1 or >=1.", + ) + if request.stop is not None and ( + not isinstance(request.stop, str) and not isinstance(request.stop, list) + ): + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.stop} is not valid under any of the given schemas - 'stop'", + ) + + return None + + +def process_input(model_name, inp): + if isinstance(inp, str): + inp = [inp] + elif isinstance(inp, list): + if isinstance(inp[0], int): + try: + decoding = tiktoken.model.encoding_for_model(model_name) + except KeyError: + logger.warning("Warning: model not found. Using cl100k_base encoding.") + model = "cl100k_base" + decoding = tiktoken.get_encoding(model) + inp = [decoding.decode(inp)] + elif isinstance(inp[0], list): + try: + decoding = tiktoken.model.encoding_for_model(model_name) + except KeyError: + logger.warning("Warning: model not found. 
Using cl100k_base encoding.") + model = "cl100k_base" + decoding = tiktoken.get_encoding(model) + inp = [decoding.decode(text) for text in inp] + + return inp + + +def create_openai_logprobs(logprob_dict): + """Create OpenAI-style logprobs.""" + return LogProbs(**logprob_dict) if logprob_dict is not None else None + + +def _add_to_set(s, new_stop): + if not s: + return + if isinstance(s, str): + new_stop.add(s) + else: + new_stop.update(s) + + +async def get_gen_params( + model_name: str, + worker_addr: str, + messages: Union[str, List[Dict[str, str]]], + *, + temperature: float, + top_p: float, + top_k: Optional[int], + presence_penalty: Optional[float], + frequency_penalty: Optional[float], + max_tokens: Optional[int], + echo: Optional[bool], + logprobs: Optional[int] = None, + stop: Optional[Union[str, List[str]]], + best_of: Optional[int] = None, + use_beam_search: Optional[bool] = None, +) -> Dict[str, Any]: + conv = await get_conv(model_name, worker_addr) + conv = Conversation( + name=conv["name"], + system_template=conv["system_template"], + system_message=conv["system_message"], + roles=conv["roles"], + messages=list(conv["messages"]), # prevent in-place modification + offset=conv["offset"], + sep_style=SeparatorStyle(conv["sep_style"]), + sep=conv["sep"], + sep2=conv["sep2"], + stop_str=conv["stop_str"], + stop_token_ids=conv["stop_token_ids"], + ) + + if isinstance(messages, str): + prompt = messages + images = [] + else: + for message in messages: + msg_role = message["role"] + if msg_role == "system": + conv.set_system_message(message["content"]) + elif msg_role == "user": + if type(message["content"]) == list: + image_list = [ + item["image_url"]["url"] + for item in message["content"] + if item["type"] == "image_url" + ] + text_list = [ + item["text"] + for item in message["content"] + if item["type"] == "text" + ] + + # TODO(chris): This only applies to LLaVA model. Implement an image_token string in the conv template. + text = "\n" * len(image_list) + text += "\n".join(text_list) + conv.append_message(conv.roles[0], (text, image_list)) + else: + conv.append_message(conv.roles[0], message["content"]) + elif msg_role == "assistant": + conv.append_message(conv.roles[1], message["content"]) + else: + raise ValueError(f"Unknown role: {msg_role}") + + # Add a blank message for the assistant. 
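+ # Passing None makes conv.get_prompt() end with the assistant role prefix,
+ # so the worker generates the assistant reply instead of extending the user turn.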
+ conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + images = conv.get_images() + + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": temperature, + "logprobs": logprobs, + "top_p": top_p, + "top_k": top_k, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, + "max_new_tokens": max_tokens, + "echo": echo, + "stop_token_ids": conv.stop_token_ids, + } + + if len(images) > 0: + gen_params["images"] = images + + if best_of is not None: + gen_params.update({"best_of": best_of}) + if use_beam_search is not None: + gen_params.update({"use_beam_search": use_beam_search}) + + new_stop = set() + _add_to_set(stop, new_stop) + _add_to_set(conv.stop_str, new_stop) + + gen_params["stop"] = list(new_stop) + + logger.debug(f"==== request ====\n{gen_params}") + return gen_params + + +async def get_worker_address(model_name: str) -> str: + """ + Get worker address based on the requested model + + :param model_name: The worker's model name + :return: Worker address from the controller + :raises: :class:`ValueError`: No available worker for requested model + """ + controller_address = app_settings.controller_address + worker_addr = await fetch_remote( + controller_address + "/get_worker_address", {"model": model_name}, "address" + ) + + # No available worker + if worker_addr == "": + raise ValueError(f"No available worker for {model_name}") + logger.debug(f"model_name: {model_name}, worker_addr: {worker_addr}") + return worker_addr + + +async def get_conv(model_name: str, worker_addr: str): + conv_template = conv_template_map.get((worker_addr, model_name)) + if conv_template is None: + conv_template = await fetch_remote( + worker_addr + "/worker_get_conv_template", {"model": model_name}, "conv" + ) + conv_template_map[(worker_addr, model_name)] = conv_template + return conv_template + + +@app.get("/v1/models", dependencies=[Depends(check_api_key)]) +async def show_available_models(): + controller_address = app_settings.controller_address + ret = await fetch_remote(controller_address + "/refresh_all_workers") + models = await fetch_remote(controller_address + "/list_models", None, "models") + + models.sort() + # TODO: return real model permission details + model_cards = [] + for m in models: + model_cards.append(ModelCard(id=m, root=m, permission=[ModelPermission()])) + return ModelList(data=model_cards) + + +@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)]) +async def create_chat_completion(request: ChatCompletionRequest): + """Creates a completion for the chat message""" + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + error_check_ret = check_requests(request) + if error_check_ret is not None: + return error_check_ret + + worker_addr = await get_worker_address(request.model) + + gen_params = await get_gen_params( + request.model, + worker_addr, + request.messages, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + max_tokens=request.max_tokens, + echo=False, + stop=request.stop, + ) + + max_new_tokens, error_check_ret = await check_length( + request, + gen_params["prompt"], + gen_params["max_new_tokens"], + worker_addr, + ) + + if error_check_ret is not None: + return error_check_ret + + gen_params["max_new_tokens"] = max_new_tokens + + if request.stream: + generator = chat_completion_stream_generator( + request.model, gen_params, 
request.n, worker_addr + ) + return StreamingResponse(generator, media_type="text/event-stream") + + choices = [] + chat_completions = [] + for i in range(request.n): + content = asyncio.create_task(generate_completion(gen_params, worker_addr)) + chat_completions.append(content) + try: + all_tasks = await asyncio.gather(*chat_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + usage = UsageInfo() + for i, content in enumerate(all_tasks): + if isinstance(content, str): + content = json.loads(content) + + if content["error_code"] != 0: + return create_error_response(content["error_code"], content["text"]) + choices.append( + ChatCompletionResponseChoice( + index=i, + message=ChatMessage(role="assistant", content=content["text"]), + finish_reason=content.get("finish_reason", "stop"), + ) + ) + if "usage" in content: + task_usage = UsageInfo.model_validate(content["usage"]) + for usage_key, usage_value in task_usage.model_dump().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return ChatCompletionResponse(model=request.model, choices=choices, usage=usage) + + +async def chat_completion_stream_generator( + model_name: str, gen_params: Dict[str, Any], n: int, worker_addr: str +) -> Generator[str, Any, None]: + """ + Event stream format: + https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format + """ + id = f"chatcmpl-{shortuuid.random()}" + finish_stream_events = [] + for i in range(n): + # First chunk with role + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(role="assistant"), + finish_reason=None, + ) + chunk = ChatCompletionStreamResponse( + id=id, choices=[choice_data], model=model_name + ) + yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + + previous_text = "" + async for content in generate_completion_stream(gen_params, worker_addr): + if content["error_code"] != 0: + yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + return + decoded_unicode = content["text"].replace("\ufffd", "") + delta_text = decoded_unicode[len(previous_text) :] + previous_text = ( + decoded_unicode + if len(decoded_unicode) > len(previous_text) + else previous_text + ) + + if len(delta_text) == 0: + delta_text = None + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=delta_text), + finish_reason=content.get("finish_reason", None), + ) + chunk = ChatCompletionStreamResponse( + id=id, choices=[choice_data], model=model_name + ) + if delta_text is None: + if content.get("finish_reason", None) is not None: + finish_stream_events.append(chunk) + continue + yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + # There is not "content" field in the last delta message, so exclude_none to exclude field "content". 
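+ # Illustrative shape of the final events (id and model values are examples):
+ #   data: {"id": "chatcmpl-abc123", "model": "vicuna-7b", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], ...}
+ #   data: [DONE]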
+ for finish_chunk in finish_stream_events: + yield f"data: {finish_chunk.model_dump_json(exclude_none=True)}\n\n" + yield "data: [DONE]\n\n" + + +@app.post("/v1/completions", dependencies=[Depends(check_api_key)]) +async def create_completion(request: CompletionRequest): + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + error_check_ret = check_requests(request) + if error_check_ret is not None: + return error_check_ret + + request.prompt = process_input(request.model, request.prompt) + + worker_addr = await get_worker_address(request.model) + for text in request.prompt: + max_tokens, error_check_ret = await check_length( + request, text, request.max_tokens, worker_addr + ) + if error_check_ret is not None: + return error_check_ret + + if isinstance(max_tokens, int) and max_tokens < request.max_tokens: + request.max_tokens = max_tokens + + if request.stream: + generator = generate_completion_stream_generator( + request, request.n, worker_addr + ) + return StreamingResponse(generator, media_type="text/event-stream") + else: + text_completions = [] + for text in request.prompt: + gen_params = await get_gen_params( + request.model, + worker_addr, + text, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + frequency_penalty=request.frequency_penalty, + presence_penalty=request.presence_penalty, + max_tokens=request.max_tokens, + logprobs=request.logprobs, + echo=request.echo, + stop=request.stop, + best_of=request.best_of, + use_beam_search=request.use_beam_search, + ) + for i in range(request.n): + content = asyncio.create_task( + generate_completion(gen_params, worker_addr) + ) + text_completions.append(content) + + try: + all_tasks = await asyncio.gather(*text_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + + choices = [] + usage = UsageInfo() + for i, content in enumerate(all_tasks): + if content["error_code"] != 0: + return create_error_response(content["error_code"], content["text"]) + choices.append( + CompletionResponseChoice( + index=i, + text=content["text"], + logprobs=create_openai_logprobs(content.get("logprobs", None)), + finish_reason=content.get("finish_reason", "stop"), + ) + ) + task_usage = UsageInfo.model_validate(content["usage"]) + for usage_key, usage_value in task_usage.model_dump().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return CompletionResponse( + model=request.model, choices=choices, usage=UsageInfo.model_validate(usage) + ) + + +async def generate_completion_stream_generator( + request: CompletionRequest, n: int, worker_addr: str +): + model_name = request.model + id = f"cmpl-{shortuuid.random()}" + finish_stream_events = [] + for text in request.prompt: + for i in range(n): + previous_text = "" + gen_params = await get_gen_params( + request.model, + worker_addr, + text, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + max_tokens=request.max_tokens, + logprobs=request.logprobs, + echo=request.echo, + stop=request.stop, + ) + async for content in generate_completion_stream(gen_params, worker_addr): + if content["error_code"] != 0: + yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + return + decoded_unicode = content["text"].replace("\ufffd", "") + delta_text = decoded_unicode[len(previous_text) :] + previous_text = ( + 
decoded_unicode + if len(decoded_unicode) > len(previous_text) + else previous_text + ) + # todo: index is not apparent + choice_data = CompletionResponseStreamChoice( + index=i, + text=delta_text, + logprobs=create_openai_logprobs(content.get("logprobs", None)), + finish_reason=content.get("finish_reason", None), + ) + chunk = CompletionStreamResponse( + id=id, + object="text_completion", + choices=[choice_data], + model=model_name, + ) + if len(delta_text) == 0: + if content.get("finish_reason", None) is not None: + finish_stream_events.append(chunk) + continue + yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + # There is not "content" field in the last delta message, so exclude_none to exclude field "content". + for finish_chunk in finish_stream_events: + yield f"data: {finish_chunk.model_dump_json(exclude_unset=True)}\n\n" + yield "data: [DONE]\n\n" + + +async def generate_completion_stream(payload: Dict[str, Any], worker_addr: str): + controller_address = app_settings.controller_address + async with httpx.AsyncClient() as client: + delimiter = b"\0" + async with client.stream( + "POST", + worker_addr + "/worker_generate_stream", + headers=headers, + json=payload, + timeout=WORKER_API_TIMEOUT, + ) as response: + # content = await response.aread() + buffer = b"" + async for raw_chunk in response.aiter_raw(): + buffer += raw_chunk + while (chunk_end := buffer.find(delimiter)) >= 0: + chunk, buffer = buffer[:chunk_end], buffer[chunk_end + 1 :] + if not chunk: + continue + yield json.loads(chunk.decode()) + + +async def generate_completion(payload: Dict[str, Any], worker_addr: str): + return await fetch_remote(worker_addr + "/worker_generate", payload, "") + + +@app.post("/v1/embeddings", dependencies=[Depends(check_api_key)]) +@app.post("/v1/engines/{model_name}/embeddings", dependencies=[Depends(check_api_key)]) +async def create_embeddings(request: EmbeddingsRequest, model_name: str = None): + """Creates embeddings for the text""" + if request.model is None: + request.model = model_name + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + + request.input = process_input(request.model, request.input) + + data = [] + token_num = 0 + batch_size = WORKER_API_EMBEDDING_BATCH_SIZE + batches = [ + request.input[i : min(i + batch_size, len(request.input))] + for i in range(0, len(request.input), batch_size) + ] + for num_batch, batch in enumerate(batches): + payload = { + "model": request.model, + "input": batch, + "encoding_format": request.encoding_format, + } + embedding = await get_embedding(payload) + if "error_code" in embedding and embedding["error_code"] != 0: + return create_error_response(embedding["error_code"], embedding["text"]) + data += [ + { + "object": "embedding", + "embedding": emb, + "index": num_batch * batch_size + i, + } + for i, emb in enumerate(embedding["embedding"]) + ] + token_num += embedding["token_num"] + return EmbeddingsResponse( + data=data, + model=request.model, + usage=UsageInfo( + prompt_tokens=token_num, + total_tokens=token_num, + completion_tokens=None, + ), + ).model_dump(exclude_none=True) + + +async def get_embedding(payload: Dict[str, Any]): + controller_address = app_settings.controller_address + model_name = payload["model"] + worker_addr = await get_worker_address(model_name) + + embedding = await fetch_remote(worker_addr + "/worker_get_embeddings", payload) + return json.loads(embedding) + + +### GENERAL API - NOT OPENAI COMPATIBLE ### + + +@app.post("/api/v1/token_check") 
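+# Illustrative request body (model name and prompt are examples); the fields
+# mirror how APITokenCheckRequest is consumed below:
+#   {"prompts": [{"model": "vicuna-7b", "prompt": "Hello!", "max_tokens": 128}]}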
+async def count_tokens(request: APITokenCheckRequest): + """ + Checks the token count for each message in your list + This is not part of the OpenAI API spec. + """ + checkedList = [] + for item in request.prompts: + worker_addr = await get_worker_address(item.model) + + context_len = await fetch_remote( + worker_addr + "/model_details", + {"prompt": item.prompt, "model": item.model}, + "context_length", + ) + + token_num = await fetch_remote( + worker_addr + "/count_token", + {"prompt": item.prompt, "model": item.model}, + "count", + ) + + can_fit = True + if token_num + item.max_tokens > context_len: + can_fit = False + + checkedList.append( + APITokenCheckResponseItem( + fits=can_fit, contextLength=context_len, tokenCount=token_num + ) + ) + + return APITokenCheckResponse(prompts=checkedList) + + +@app.post("/api/v1/chat/completions") +async def create_chat_completion(request: APIChatCompletionRequest): + """Creates a completion for the chat message""" + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + error_check_ret = check_requests(request) + if error_check_ret is not None: + return error_check_ret + + worker_addr = await get_worker_address(request.model) + + gen_params = await get_gen_params( + request.model, + worker_addr, + request.messages, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + max_tokens=request.max_tokens, + echo=False, + stop=request.stop, + ) + + if request.repetition_penalty is not None: + gen_params["repetition_penalty"] = request.repetition_penalty + + max_new_tokens, error_check_ret = await check_length( + request, + gen_params["prompt"], + gen_params["max_new_tokens"], + worker_addr, + ) + + if error_check_ret is not None: + return error_check_ret + + gen_params["max_new_tokens"] = max_new_tokens + + if request.stream: + generator = chat_completion_stream_generator( + request.model, gen_params, request.n, worker_addr + ) + return StreamingResponse(generator, media_type="text/event-stream") + + choices = [] + chat_completions = [] + for i in range(request.n): + content = asyncio.create_task(generate_completion(gen_params, worker_addr)) + chat_completions.append(content) + try: + all_tasks = await asyncio.gather(*chat_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + usage = UsageInfo() + for i, content in enumerate(all_tasks): + if content["error_code"] != 0: + return create_error_response(content["error_code"], content["text"]) + choices.append( + ChatCompletionResponseChoice( + index=i, + message=ChatMessage(role="assistant", content=content["text"]), + finish_reason=content.get("finish_reason", "stop"), + ) + ) + task_usage = UsageInfo.model_validate(content["usage"]) + for usage_key, usage_value in task_usage.model_dump().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return ChatCompletionResponse(model=request.model, choices=choices, usage=usage) + + +### END GENERAL API - NOT OPENAI COMPATIBLE ### + + +def create_openai_api_server(): + parser = argparse.ArgumentParser( + description="FastChat ChatGPT-Compatible RESTful API server." 
+ ) + parser.add_argument("--host", type=str, default="localhost", help="host name") + parser.add_argument("--port", type=int, default=8000, help="port number") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument( + "--allow-credentials", action="store_true", help="allow credentials" + ) + parser.add_argument( + "--allowed-origins", type=json.loads, default=["*"], help="allowed origins" + ) + parser.add_argument( + "--allowed-methods", type=json.loads, default=["*"], help="allowed methods" + ) + parser.add_argument( + "--allowed-headers", type=json.loads, default=["*"], help="allowed headers" + ) + parser.add_argument( + "--api-keys", + type=lambda s: s.split(","), + help="Optional list of comma separated API keys", + ) + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) + args = parser.parse_args() + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + app_settings.controller_address = args.controller_address + app_settings.api_keys = args.api_keys + + logger.info(f"args: {args}") + return args + + +if __name__ == "__main__": + args = create_openai_api_server() + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/fastchat/serve/register_worker.py b/fastchat/serve/register_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..aa57117b9106730b0731df10dbfb0a2b1bbe381b --- /dev/null +++ b/fastchat/serve/register_worker.py @@ -0,0 +1,28 @@ +""" +Manually register workers. + +Usage: +python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 +""" + +import argparse + +import requests + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--controller-address", type=str) + parser.add_argument("--worker-name", type=str) + parser.add_argument("--check-heart-beat", action="store_true") + parser.add_argument("--multimodal", action="store_true") + args = parser.parse_args() + + url = args.controller_address + "/register_worker" + data = { + "worker_name": args.worker_name, + "check_heart_beat": args.check_heart_beat, + "worker_status": None, + "multimodal": args.multimodal, + } + r = requests.post(url, json=data) + assert r.status_code == 200 diff --git a/fastchat/serve/remote_logger.py b/fastchat/serve/remote_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..549d49811848d83da27eb117d63b2e9e60e59c3f --- /dev/null +++ b/fastchat/serve/remote_logger.py @@ -0,0 +1,59 @@ +# A JSON logger that sends data to remote endpoint. +# Architecturally, it hosts a background thread that sends logs to a remote endpoint. 
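+# Example usage (a minimal sketch; it assumes REMOTE_LOGGER_URL is set and the
+# payload fields shown are illustrative):
+#
+#   from fastchat.serve.remote_logger import get_remote_logger
+#   get_remote_logger().log({"type": "chat", "model": "vicuna-7b", "ip": "127.0.0.1"})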
+import os +import json +import requests +import threading +import queue +import logging + +_global_logger = None + + +def get_remote_logger(): + global _global_logger + if _global_logger is None: + if url := os.environ.get("REMOTE_LOGGER_URL"): + logging.info(f"Remote logger enabled, sending data to {url}") + _global_logger = RemoteLogger(url=url) + else: + _global_logger = EmptyLogger() + return _global_logger + + +class EmptyLogger: + """Dummy logger that does nothing.""" + + def __init__(self): + pass + + def log(self, _data: dict): + pass + + +class RemoteLogger: + """A JSON logger that sends data to remote endpoint.""" + + def __init__(self, url: str): + self.url = url + + self.logs = queue.Queue() + self.thread = threading.Thread(target=self._send_logs, daemon=True) + self.thread.start() + + def log(self, data: dict): + self.logs.put_nowait(data) + + def _send_logs(self): + while True: + data = self.logs.get() + + # process the data by keep only the top level fields, and turn any nested dict into a string + for key, value in data.items(): + if isinstance(value, (dict, list, tuple)): + data[key] = json.dumps(value, ensure_ascii=False) + + try: + requests.post(self.url, json=data) + except Exception: + logging.exception("Failed to send logs to remote endpoint") diff --git a/fastchat/serve/sglang_worker.py b/fastchat/serve/sglang_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..2e9b607709e79e5bc29a4ecf08d3236795a4eded --- /dev/null +++ b/fastchat/serve/sglang_worker.py @@ -0,0 +1,312 @@ +""" +A model worker that executes the model based on SGLang. + +Usage: +python3 -m fastchat.serve.sglang_worker --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000 --worker-address http://localhost:30000 +""" + +import argparse +import asyncio +import json +import multiprocessing +from typing import List + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn +import sglang as sgl +from sglang.srt.hf_transformers_utils import get_tokenizer, get_config +from sglang.srt.utils import load_image, is_multimodal_model + +from fastchat.conversation import IMAGE_PLACEHOLDER_STR +from fastchat.constants import ErrorCode, SERVER_ERROR_MSG +from fastchat.serve.base_model_worker import BaseModelWorker +from fastchat.serve.model_worker import ( + logger, + worker_id, +) +from fastchat.utils import get_context_length, is_partial_stop + +app = FastAPI() + + +@sgl.function +def pipeline(s, prompt, max_tokens): + for p in prompt: + if isinstance(p, str): + s += p + else: + s += sgl.image(p) + s += sgl.gen("response", max_tokens=max_tokens) + + +class SGLWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + tokenizer_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + conv_template: str, + runtime: sgl.Runtime, + trust_remote_code: bool, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template, + is_multimodal_model(model_path), + ) + + logger.info( + f"Loading the model {self.model_names} on worker {worker_id}, worker type: SGLang worker..." 
+ ) + + self.tokenizer = get_tokenizer(tokenizer_path) + self.context_len = get_context_length( + get_config(model_path, trust_remote_code=trust_remote_code) + ) + + if not no_register: + self.init_heart_beat() + + async def generate_stream(self, params): + self.call_ct += 1 + + prompt = params.pop("prompt") + images = params.get("images", []) + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = params.get("top_k", -1.0) + frequency_penalty = float(params.get("frequency_penalty", 0.0)) + presence_penalty = float(params.get("presence_penalty", 0.0)) + max_new_tokens = params.get("max_new_tokens", 256) + stop_str = params.get("stop", None) + stop_token_ids = params.get("stop_token_ids", None) or [] + echo = params.get("echo", True) + + # Handle stop_str + stop = [] + if isinstance(stop_str, str) and stop_str != "": + stop.append(stop_str) + elif isinstance(stop_str, list) and stop_str != []: + stop.extend(stop_str) + + for tid in stop_token_ids: + if tid is not None: + s = self.tokenizer.decode(tid) + if s != "": + stop.append(s) + + # make sampling params for sgl.gen + top_p = max(top_p, 1e-5) + if temperature <= 1e-5: + top_p = 1.0 + + # split prompt by image token + split_prompt = prompt.split(IMAGE_PLACEHOLDER_STR) + if prompt.count(IMAGE_PLACEHOLDER_STR) != len(images): + raise ValueError( + "The number of images passed in does not match the number of tokens in the prompt!" + ) + prompt = [] + for i in range(len(split_prompt)): + prompt.append(split_prompt[i]) + if i < len(images): + prompt[-1] = prompt[-1].strip() + prompt.append(load_image(images[i])) + + state = pipeline.run( + prompt, + max_new_tokens, + stop=stop, + temperature=temperature, + top_p=top_p, + top_k=top_k, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + stream=True, + ) + + entire_output = prompt if echo else "" + async for out, meta_info in state.text_async_iter( + var_name="response", return_meta_data=True + ): + partial_stop = any(is_partial_stop(out, i) for i in stop) + + # prevent yielding partial stop sequence + if partial_stop: + continue + + entire_output += out + prompt_tokens = meta_info["prompt_tokens"] + completion_tokens = meta_info["completion_tokens"] + + ret = { + "text": entire_output, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + "error_code": 0, + } + yield ret + + async def generate_stream_gate(self, params): + try: + async for ret in self.generate_stream(params): + yield json.dumps(ret).encode() + b"\0" + except (ValueError, RuntimeError) as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + yield json.dumps(ret).encode() + b"\0" + + async def generate_gate(self, params): + async for x in self.generate_stream_gate(params): + pass + return json.loads(x[:-1].decode()) + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(): + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + generator = worker.generate_stream_gate(params) + 
background_tasks = create_background_tasks() + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + output = await worker.generate_gate(params) + release_worker_semaphore() + return JSONResponse(output) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--model-path", type=str, default="lmsys/vicuna-7b-v1.5") + parser.add_argument("--tokenizer-path", type=str, default="") + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + parser.add_argument("--limit-worker-concurrency", type=int, default=1024) + parser.add_argument("--no-register", action="store_true") + parser.add_argument("--num-gpus", type=int, default=1) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument( + "--trust-remote-code", + action="store_false", + default=True, + help="Trust remote code (e.g., from HuggingFace) when" + "downloading the model and tokenizer.", + ) + parser.add_argument( + "--mem-fraction-static", + type=float, + default=0.9, + help="The ratio (between 0 and 1) of GPU memory to" + "reserve for the model weights, activations, and KV cache. Higher" + "values will increase the KV cache size and thus improve the model's" + "throughput. 
However, if the value is too high, it may cause out-of-" + "memory (OOM) errors.", + ) + parser.add_argument( + "--multimodal", + action="store_true", + required=False, + default=False, + help="Register this worker as serving a multimodal model.", + ) + + args = parser.parse_args() + + args.tp_size = args.num_gpus if args.num_gpus > 1 else 1 + args.tokenizer_path = ( + args.model_path if args.tokenizer_path == "" else args.tokenizer_path + ) + + multiprocessing.set_start_method("spawn", force=True) + runtime = sgl.Runtime( + model_path=args.model_path, + tokenizer_path=args.tokenizer_path, + trust_remote_code=args.trust_remote_code, + mem_fraction_static=args.mem_fraction_static, + tp_size=args.tp_size, + ) + sgl.set_default_backend(runtime) + + worker = SGLWorker( + args.controller_address, + args.worker_address, + worker_id, + args.model_path, + args.tokenizer_path, + args.model_names, + args.limit_worker_concurrency, + args.no_register, + args.conv_template, + runtime, + args.trust_remote_code, + ) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/fastchat/serve/shutdown_serve.py b/fastchat/serve/shutdown_serve.py new file mode 100644 index 0000000000000000000000000000000000000000..95e2b704f0b65584c5be15ce14b40bc150bd6009 --- /dev/null +++ b/fastchat/serve/shutdown_serve.py @@ -0,0 +1,24 @@ +""" +Usage: +python shutdown_serve.py --down all +options: "all","controller","model_worker","openai_api_server", `all` means to stop all related servers +""" + +import argparse +import os +import subprocess + +parser = argparse.ArgumentParser() +parser.add_argument( + "--down", choices=["all", "controller", "model_worker", "openai_api_server"] +) +args = parser.parse_args() +base_shell = "ps -eo user,pid,cmd|grep fastchat.serve{}|grep -v grep|awk '{{print $2}}'|xargs kill -9" +if args.down == "all": + shell_script = base_shell.format("") +else: + serve = f".{args.down}" + shell_script = base_shell.format(serve) +print(f"execute shell cmd: {shell_script}") +subprocess.run(shell_script, shell=True, check=True) +print(f"{args.down} has been shutdown!") diff --git a/fastchat/serve/test_message.py b/fastchat/serve/test_message.py new file mode 100644 index 0000000000000000000000000000000000000000..203a44901c10c5526f198c8e9dbb4e32d15ed7aa --- /dev/null +++ b/fastchat/serve/test_message.py @@ -0,0 +1,81 @@ +"""Send a test message.""" +import argparse +import json + +import requests + +from fastchat.model.model_adapter import get_conversation_template + + +def main(): + model_name = args.model_name + + if args.worker_address: + worker_addr = args.worker_address + else: + controller_addr = args.controller_address + ret = requests.post(controller_addr + "/refresh_all_workers") + ret = requests.post(controller_addr + "/list_models") + models = ret.json()["models"] + models.sort() + print(f"Models: {models}") + + ret = requests.post( + controller_addr + "/get_worker_address", json={"model": model_name} + ) + worker_addr = ret.json()["address"] + print(f"worker_addr: {worker_addr}") + + if worker_addr == "": + print(f"No available workers for {model_name}") + return + + conv = get_conversation_template(model_name) + conv.append_message(conv.roles[0], args.message) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + headers = {"User-Agent": "FastChat Client"} + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": args.temperature, + "max_new_tokens": args.max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, 
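+        # echo=False asks the worker to stream only the newly generated text,
+        # without repeating the prompt.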
+ "echo": False, + } + response = requests.post( + worker_addr + "/worker_generate_stream", + headers=headers, + json=gen_params, + stream=True, + ) + + print(f"{conv.roles[0]}: {args.message}") + print(f"{conv.roles[1]}: ", end="") + prev = 0 + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + output = data["text"].strip() + print(output[prev:], end="", flush=True) + prev = len(output) + print("") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--worker-address", type=str) + parser.add_argument("--model-name", type=str, required=True) + parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--max-new-tokens", type=int, default=32) + parser.add_argument( + "--message", type=str, default="Tell me a story with more than 1000 words." + ) + args = parser.parse_args() + + main() diff --git a/fastchat/serve/test_throughput.py b/fastchat/serve/test_throughput.py new file mode 100644 index 0000000000000000000000000000000000000000..3796a6e2a7cb53dc6921674fc4c488246e0b93c7 --- /dev/null +++ b/fastchat/serve/test_throughput.py @@ -0,0 +1,115 @@ +"""Benchmarking script to test the throughput of serving workers.""" +import argparse +import json + +import requests +import threading +import time + +from fastchat.conversation import get_conv_template + + +def main(): + if args.worker_address: + worker_addr = args.worker_address + else: + controller_addr = args.controller_address + ret = requests.post(controller_addr + "/refresh_all_workers") + ret = requests.post(controller_addr + "/list_models") + models = ret.json()["models"] + models.sort() + print(f"Models: {models}") + + ret = requests.post( + controller_addr + "/get_worker_address", json={"model": args.model_name} + ) + worker_addr = ret.json()["address"] + print(f"worker_addr: {worker_addr}") + + if worker_addr == "": + return + + conv = get_conv_template("vicuna_v1.1") + conv.append_message(conv.roles[0], "Tell me a story with more than 1000 words") + prompt_template = conv.get_prompt() + prompts = [prompt_template for _ in range(args.n_thread)] + + headers = {"User-Agent": "fastchat Client"} + ploads = [ + { + "model": args.model_name, + "prompt": prompts[i], + "max_new_tokens": args.max_new_tokens, + "temperature": 0.0, + # "stop": conv.sep, + } + for i in range(len(prompts)) + ] + + def send_request(results, i): + if args.test_dispatch: + ret = requests.post( + controller_addr + "/get_worker_address", json={"model": args.model_name} + ) + thread_worker_addr = ret.json()["address"] + else: + thread_worker_addr = worker_addr + print(f"thread {i} goes to {thread_worker_addr}") + response = requests.post( + thread_worker_addr + "/worker_generate_stream", + headers=headers, + json=ploads[i], + stream=False, + ) + k = list( + response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0") + ) + # print(k) + response_new_words = json.loads(k[-2].decode("utf-8"))["text"] + error_code = json.loads(k[-2].decode("utf-8"))["error_code"] + # print(f"=== Thread {i} ===, words: {1}, error code: {error_code}") + results[i] = len(response_new_words.split(" ")) - len(prompts[i].split(" ")) + + # use N threads to prompt the backend + tik = time.time() + threads = [] + results = [None] * args.n_thread + for i in range(args.n_thread): + t = threading.Thread(target=send_request, args=(results, i)) + t.start() + # 
time.sleep(0.5) + threads.append(t) + + for t in threads: + t.join() + + print(f"Time (POST): {time.time() - tik} s") + # n_words = 0 + # for i, response in enumerate(results): + # # print(prompt[i].replace(conv.sep, "\n"), end="") + # # make sure the streaming finishes at EOS or stopping criteria + # k = list(response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0")) + # response_new_words = json.loads(k[-2].decode("utf-8"))["text"] + # # print(response_new_words) + # n_words += len(response_new_words.split(" ")) - len(prompts[i].split(" ")) + n_words = sum(results) + time_seconds = time.time() - tik + print( + f"Time (Completion): {time_seconds}, n threads: {args.n_thread}, " + f"throughput: {n_words / time_seconds} words/s." + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--worker-address", type=str) + parser.add_argument("--model-name", type=str, default="vicuna") + parser.add_argument("--max-new-tokens", type=int, default=2048) + parser.add_argument("--n-thread", type=int, default=8) + parser.add_argument("--test-dispatch", action="store_true") + args = parser.parse_args() + + main() diff --git a/fastchat/serve/vision/create_vqa_examples_dir.py b/fastchat/serve/vision/create_vqa_examples_dir.py new file mode 100644 index 0000000000000000000000000000000000000000..2f2630bd985cba2c02f08ec59d1fb237d5d47545 --- /dev/null +++ b/fastchat/serve/vision/create_vqa_examples_dir.py @@ -0,0 +1,127 @@ +import datasets +from datasets import load_dataset +from PIL import Image +from pathlib import Path +import pandas as pd +import os +import json +import tqdm +import argparse +import shutil +import numpy as np + +np.random.seed(0) + +""" +Creates a directory with images and JSON files for VQA examples. 
Final json is located in metadata_sampled.json +""" + + +def download_images_and_create_json( + dataset_info, cache_dir="~/vqa_examples_cache", base_dir="./vqa_examples" +): + for dataset_name, info in dataset_info.items(): + dataset_cache_dir = os.path.join(cache_dir, dataset_name) + os.makedirs(dataset_cache_dir, exist_ok=True) + + if info["subset"]: + dataset = load_dataset( + info["path"], + info["subset"], + cache_dir=dataset_cache_dir, + split=info["split"], + ) + else: + dataset = load_dataset( + info["path"], cache_dir=dataset_cache_dir, split=info["split"] + ) + dataset_dir = os.path.join(base_dir, dataset_name) + os.makedirs(dataset_dir, exist_ok=True) + + json_data = [] + for i, item in enumerate(tqdm.tqdm(dataset)): + id_key = i if info["id_key"] == "index" else item[info["id_key"]] + image_pil = item[info["image_key"]].convert("RGB") + image_path = os.path.join(dataset_dir, f"{id_key}.jpg") + image_pil.save(image_path) + json_entry = { + "dataset": dataset_name, + "question": item[info["question_key"]], + "path": image_path, + } + json_data.append(json_entry) + + with open(os.path.join(dataset_dir, "data.json"), "w") as json_file: + json.dump(json_data, json_file, indent=4) + # Delete the cache directory for the dataset + shutil.rmtree(dataset_cache_dir, ignore_errors=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data_dir", type=str, default="~/.cache") + parser.add_argument("--output_dir", type=str, default="./vqa_examples") + args = parser.parse_args() + + datasets_info = { + "DocVQA": { + "path": "lmms-lab/DocVQA", + "image_key": "image", + "question_key": "question", + "id_key": "questionId", + "subset": "DocVQA", + "split": "test", + }, + "ChartQA": { + "path": "HuggingFaceM4/ChartQA", + "image_key": "image", + "question_key": "query", + "id_key": "index", + "subset": False, + "split": "test", + }, + "realworldqa": { + "path": "visheratin/realworldqa", + "image_key": "image", + "question_key": "question", + "id_key": "index", + "subset": False, + "split": "test", + }, + "NewYorker": { + "path": "jmhessel/newyorker_caption_contest", + "image_key": "image", + "question_key": "questions", + "id_key": "index", + "subset": "explanation", + "split": "train", + }, + "WikiArt": { + "path": "huggan/wikiart", + "image_key": "image", + "question_key": "artist", + "id_key": "index", + "subset": False, + "split": "train", + }, + "TextVQA": { + "path": "facebook/textvqa", + "image_key": "image", + "question_key": "question", + "id_key": "question_id", + "subset": False, + "split": "train", + }, + } + + download_images_and_create_json( + datasets_info, cache_dir=args.data_dir, base_dir=args.output_dir + ) + dataset_json = [] + for dataset_name in datasets_info.keys(): + with open(f"{args.output_dir}/{dataset_name}/data.json") as f: + data = json.load(f) + dataset_json.extend(np.random.choice(data, 500)) + + with open(f"{args.output_dir}/metadata_sampled.json", "w") as f: + json.dump(dataset_json, f, indent=4) diff --git a/fastchat/serve/vision/create_vqa_examples_json.py b/fastchat/serve/vision/create_vqa_examples_json.py new file mode 100644 index 0000000000000000000000000000000000000000..a5e549390bb77d8969b703a382a9b13c43fc0376 --- /dev/null +++ b/fastchat/serve/vision/create_vqa_examples_json.py @@ -0,0 +1,34 @@ +""" +Changes proportion of examples in metadata_sampled.json + +Usage: + +python3 -m fastchat.serve.vision.create_vqa_examples_json +""" + +import json +import argparse +import numpy as np + +if __name__ == "__main__": + 
parser = argparse.ArgumentParser() + parser.add_argument("--data_dir", type=str, default="~/.cache") + parser.add_argument("--output_dir", type=str, default="./vqa_examples") + args = parser.parse_args() + + dataset_prop = { + "DocVQA": 500, + "ChartQA": 500, + "NewYorker": 1000, + "WikiArt": 500, + "TextVQA": 500, + } + + dataset_json = [] + for dataset_name in dataset_prop.keys(): + with open(f"{args.output_dir}/{dataset_name}/data.json") as f: + data = json.load(f) + dataset_json.extend(np.random.choice(data, dataset_prop[dataset_name])) + + with open(f"{args.output_dir}/metadata_sampled.json", "w") as f: + json.dump(dataset_json, f, indent=4) diff --git a/fastchat/serve/vision/image.py b/fastchat/serve/vision/image.py new file mode 100644 index 0000000000000000000000000000000000000000..58bedf8f07b976ca4b2b6866f49eae2483dd521d --- /dev/null +++ b/fastchat/serve/vision/image.py @@ -0,0 +1,135 @@ +import base64 +from enum import auto, IntEnum +from io import BytesIO + +from pydantic import BaseModel + + +class ImageFormat(IntEnum): + """Image formats.""" + + URL = auto() + LOCAL_FILEPATH = auto() + PIL_IMAGE = auto() + BYTES = auto() + DEFAULT = auto() + + +class Image(BaseModel): + url: str = "" + filetype: str = "" + image_format: ImageFormat = ImageFormat.BYTES + base64_str: str = "" + + def convert_image_to_base64(self): + """Given an image, return the base64 encoded image string.""" + from PIL import Image + import requests + + # Load image if it has not been loaded in yet + if self.image_format == ImageFormat.URL: + response = requests.get(image) + image = Image.open(BytesIO(response.content)).convert("RGBA") + image_bytes = BytesIO() + image.save(image_bytes, format="PNG") + elif self.image_format == ImageFormat.LOCAL_FILEPATH: + image = Image.open(self.url).convert("RGBA") + image_bytes = BytesIO() + image.save(image_bytes, format="PNG") + elif self.image_format == ImageFormat.BYTES: + image_bytes = image + + img_b64_str = base64.b64encode(image_bytes).decode() + + return img_b64_str + + def to_openai_image_format(self): + if self.image_format == ImageFormat.URL: # input is a url + return self.url + elif self.image_format == ImageFormat.LOCAL_FILEPATH: # input is a local image + self.base64_str = self.convert_image_to_base64(self.url) + return f"data:image/{self.filetype};base64,{self.base64_str}" + elif self.image_format == ImageFormat.BYTES: + return f"data:image/{self.filetype};base64,{self.base64_str}" + else: + raise ValueError( + f"This file is not valid or not currently supported by the OpenAI API: {self.url}" + ) + + def resize_image_and_return_image_in_bytes(self, image, max_image_size_mb): + import math + + image_format = "png" + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + max_len, min_len = 1024, 1024 + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if longest_edge != max(image.size): + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + + image_bytes = BytesIO() + image.save(image_bytes, format="PNG") + if max_image_size_mb: + target_size_bytes = max_image_size_mb * 1024 * 1024 + + current_size_bytes = image_bytes.tell() + if current_size_bytes > target_size_bytes: + resize_factor = (target_size_bytes / current_size_bytes) ** 0.5 + new_width = math.floor(image.width * resize_factor) + new_height = math.floor(image.height * resize_factor) + image = 
image.resize((new_width, new_height)) + + image_bytes = BytesIO() + image.save(image_bytes, format="PNG") + current_size_bytes = image_bytes.tell() + + image_bytes.seek(0) + + return image_format, image_bytes + + def convert_url_to_image_bytes(self, max_image_size_mb): + from PIL import Image + + if self.url.endswith(".svg"): + import cairosvg + + with open(self.url, "rb") as svg_file: + svg_data = svg_file.read() + + png_data = cairosvg.svg2png(bytestring=svg_data) + pil_image = Image.open(BytesIO(png_data)).convert("RGBA") + else: + pil_image = Image.open(self.url).convert("RGBA") + + image_format, image_bytes = self.resize_image_and_return_image_in_bytes( + pil_image, max_image_size_mb + ) + + img_base64_str = base64.b64encode(image_bytes.getvalue()).decode() + + return image_format, img_base64_str + + def to_conversation_format(self, max_image_size_mb): + image_format, image_bytes = self.convert_url_to_image_bytes( + max_image_size_mb=max_image_size_mb + ) + + self.filetype = image_format + self.image_format = ImageFormat.BYTES + self.base64_str = image_bytes + + return self + + +if __name__ == "__main__": + image = Image(url="fastchat/serve/example_images/fridge.jpg") + image.to_conversation_format(max_image_size_mb=5 / 1.5) + + json_str = image.model_dump_json() + print(json_str) diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..0af680bb5f1cabcf77b9a67bd9d542b53550f89f --- /dev/null +++ b/fastchat/serve/vllm_worker.py @@ -0,0 +1,302 @@ +""" +A model worker that executes the model based on vLLM. + +See documentations at docs/vllm_integration.md +""" + +import argparse +import asyncio +import json +from typing import List + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn +from vllm import AsyncLLMEngine +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.sampling_params import SamplingParams +from vllm.utils import random_uuid + +from fastchat.serve.base_model_worker import BaseModelWorker +from fastchat.serve.model_worker import ( + logger, + worker_id, +) +from fastchat.utils import get_context_length, is_partial_stop + + +app = FastAPI() + + +class VLLMWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + llm_engine: AsyncLLMEngine, + conv_template: str, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template, + ) + + logger.info( + f"Loading the model {self.model_names} on worker {worker_id}, worker type: vLLM worker..." 
+ ) + self.tokenizer = llm_engine.engine.tokenizer + # This is to support vllm >= 0.2.7 where TokenizerGroup was introduced + # and llm_engine.engine.tokenizer was no longer a raw tokenizer + if hasattr(self.tokenizer, "tokenizer"): + self.tokenizer = llm_engine.engine.tokenizer.tokenizer + self.context_len = get_context_length(llm_engine.engine.model_config.hf_config) + + if not no_register: + self.init_heart_beat() + + async def generate_stream(self, params): + self.call_ct += 1 + + context = params.pop("prompt") + request_id = params.pop("request_id") + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = params.get("top_k", -1.0) + presence_penalty = float(params.get("presence_penalty", 0.0)) + frequency_penalty = float(params.get("frequency_penalty", 0.0)) + max_new_tokens = params.get("max_new_tokens", 256) + stop_str = params.get("stop", None) + stop_token_ids = params.get("stop_token_ids", None) or [] + if self.tokenizer.eos_token_id is not None: + stop_token_ids.append(self.tokenizer.eos_token_id) + echo = params.get("echo", True) + use_beam_search = params.get("use_beam_search", False) + best_of = params.get("best_of", None) + + request = params.get("request", None) + + # Handle stop_str + stop = set() + if isinstance(stop_str, str) and stop_str != "": + stop.add(stop_str) + elif isinstance(stop_str, list) and stop_str != []: + stop.update(stop_str) + + for tid in stop_token_ids: + if tid is not None: + s = self.tokenizer.decode(tid) + if s != "": + stop.add(s) + + # make sampling params in vllm + top_p = max(top_p, 1e-5) + if temperature <= 1e-5: + top_p = 1.0 + + sampling_params = SamplingParams( + n=1, + temperature=temperature, + top_p=top_p, + use_beam_search=use_beam_search, + stop=list(stop), + stop_token_ids=stop_token_ids, + max_tokens=max_new_tokens, + top_k=top_k, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + best_of=best_of, + ) + results_generator = engine.generate(context, sampling_params, request_id) + + async for request_output in results_generator: + prompt = request_output.prompt + if echo: + text_outputs = [ + prompt + output.text for output in request_output.outputs + ] + else: + text_outputs = [output.text for output in request_output.outputs] + text_outputs = " ".join(text_outputs) + + partial_stop = any(is_partial_stop(text_outputs, i) for i in stop) + # prevent yielding partial stop sequence + if partial_stop: + continue + + aborted = False + if request and await request.is_disconnected(): + await engine.abort(request_id) + request_output.finished = True + aborted = True + for output in request_output.outputs: + output.finish_reason = "abort" + + prompt_tokens = len(request_output.prompt_token_ids) + completion_tokens = sum( + len(output.token_ids) for output in request_output.outputs + ) + ret = { + "text": text_outputs, + "error_code": 0, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + "cumulative_logprob": [ + output.cumulative_logprob for output in request_output.outputs + ], + "finish_reason": request_output.outputs[0].finish_reason + if len(request_output.outputs) == 1 + else [output.finish_reason for output in request_output.outputs], + } + # Emit twice here to ensure a 'finish_reason' with empty content in the OpenAI API response. + # This aligns with the behavior of model_worker. 
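+            # On the final iteration the chunk is therefore yielded once with
+            # finish_reason=None and once with the real finish_reason.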
+ if request_output.finished: + yield (json.dumps({**ret, **{"finish_reason": None}}) + "\0").encode() + yield (json.dumps(ret) + "\0").encode() + + if aborted: + break + + async def generate(self, params): + async for x in self.generate_stream(params): + pass + return json.loads(x[:-1].decode()) + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(request_id): + async def abort_request() -> None: + await engine.abort(request_id) + + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + background_tasks.add_task(abort_request) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + request_id = random_uuid() + params["request_id"] = request_id + params["request"] = request + generator = worker.generate_stream(params) + background_tasks = create_background_tasks(request_id) + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + request_id = random_uuid() + params["request_id"] = request_id + params["request"] = request + output = await worker.generate(params) + release_worker_semaphore() + await engine.abort(request_id) + return JSONResponse(output) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--model-path", type=str, default="lmsys/vicuna-7b-v1.5") + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + parser.add_argument("--limit-worker-concurrency", type=int, default=1024) + parser.add_argument("--no-register", action="store_true") + parser.add_argument("--num-gpus", type=int, default=1) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument( + "--trust_remote_code", + action="store_false", + default=True, + help="Trust remote code (e.g., from HuggingFace) when" + "downloading the model and tokenizer.", + ) + parser.add_argument( + "--gpu_memory_utilization", + type=float, + default=0.9, + help="The ratio (between 0 and 1) of GPU memory to" + "reserve for the model weights, activations, and KV cache. Higher" + "values will increase the KV cache size and thus improve the model's" + "throughput. 
However, if the value is too high, it may cause out-of-" + "memory (OOM) errors.", + ) + + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + if args.model_path: + args.model = args.model_path + if args.num_gpus > 1: + args.tensor_parallel_size = args.num_gpus + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args(engine_args) + worker = VLLMWorker( + args.controller_address, + args.worker_address, + worker_id, + args.model_path, + args.model_names, + args.limit_worker_concurrency, + args.no_register, + engine, + args.conv_template, + ) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/fastchat/train/llama2_flash_attn_monkey_patch.py b/fastchat/train/llama2_flash_attn_monkey_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..c1fe51c91bd553f8fbe0c25e9c88fe1abc3542e7 --- /dev/null +++ b/fastchat/train/llama2_flash_attn_monkey_patch.py @@ -0,0 +1,238 @@ +import warnings +from typing import Optional, Tuple + +import torch +from flash_attn import __version__ as flash_attn_version +from flash_attn.bert_padding import pad_input, unpad_input +from flash_attn.flash_attn_interface import ( + flash_attn_func, + flash_attn_varlen_kvpacked_func, +) +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaModel, + rotate_half, +) + + +def apply_rotary_pos_emb(q, k, cos_sin, position_ids): + gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1] + gather_indices = gather_indices.repeat( + 1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3] + ) + bsz = gather_indices.shape[0] + cos, sin = ( + torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices) + for x in cos_sin + ) + q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k)) + return q, k + + +def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + warnings.warn( + "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 
+ ) + + bsz, q_len, _ = hidden_states.size() + kv_heads = getattr(self, "num_key_value_heads", self.num_heads) + + q, k, v = ( + op(hidden_states).view(bsz, q_len, nh, self.head_dim) + for op, nh in ( + (self.q_proj, self.num_heads), + (self.k_proj, kv_heads), + (self.v_proj, kv_heads), + ) + ) + # shape: (b, s, num_heads, head_dim) + + kv_seq_len = k.shape[1] + past_kv_len = 0 + if past_key_value is not None: + past_kv_len = past_key_value[0].shape[2] + kv_seq_len += past_kv_len + + cos_sin = self.rotary_emb(v, seq_len=kv_seq_len) + q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids) + + if past_key_value is not None: + assert ( + flash_attn_version >= "2.1.0" + ), "past_key_value support requires flash-attn >= 2.1.0" + # reuse k, v + k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1) + v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1) + + past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None + + if attention_mask is None: + output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view( + bsz, q_len, -1 + ) + else: + q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:]) + # We can skip concat and call unpad twice but seems better to call unpad only once. + kv, _, cu_k_lens, max_k = unpad_input( + torch.stack((k, v), dim=2), attention_mask + ) + output_unpad = flash_attn_varlen_kvpacked_func( + q, + kv, + cu_q_lens, + cu_k_lens, + max_s, + max_k, + 0.0, + softmax_scale=None, + causal=True, + ) + output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) + output = pad_input(output_unpad, indices, bsz, q_len) + + return self.o_proj(output), None, past_key_value + + +# Disable the transformation of the attention mask in LlamaModel as flash attention +# takes a boolean key_padding_mask. Fills in the past kv length for use in forward. +def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length +): + # [bsz, seq_len] + if past_key_values_length > 0 and attention_mask is not None: + attention_mask = torch.cat( + ( + torch.full( + (input_shape[0], past_key_values_length), + True, + dtype=attention_mask.dtype, + device=attention_mask.device, + ), + attention_mask, + ), + dim=-1, + ) + + if attention_mask is not None and torch.all(attention_mask): + return None # This uses the faster call when training with full samples + + return attention_mask + + +def replace_llama_attn_with_flash_attn(): + cuda_major, cuda_minor = torch.cuda.get_device_capability() + if cuda_major < 8: + warnings.warn( + "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
+ "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" + ) + + LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + LlamaAttention.forward = forward + + +def test(): + from fastchat.train.llama_flash_attn_monkey_patch import forward as fastchat_forward + from transformers.models.llama.configuration_llama import LlamaConfig + + config = LlamaConfig( + hidden_size=1024, + intermediate_size=128, + num_hidden_layers=1, + num_attention_heads=8, + max_position_embeddings=16, + ) + device = torch.device("cuda") + model = LlamaModel(config) + attn = LlamaAttention(config).to(device).half() + bsz, hs, seqlen = 2, config.hidden_size, config.max_position_embeddings + position_ids = torch.arange(seqlen, dtype=torch.long, device=device).view( + -1, seqlen + ) + + mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device) + for i in range(4): + hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device) + if i: + mask[0, -i:] = False + mask[1, :i] = False + + lmask = model._prepare_decoder_attention_mask(mask, hidden.shape[:2], hidden, 0) + ref, _, _ = attn.forward( + hidden, attention_mask=lmask, position_ids=position_ids + ) + + fast, _, _ = fastchat_forward( + attn, hidden, attention_mask=mask, position_ids=position_ids + ) + + lmask = _prepare_decoder_attention_mask( + model, mask, hidden.shape[:2], hidden, 0 + ) + test, _, _ = forward( + attn, hidden, attention_mask=lmask, position_ids=position_ids + ) + + print(f"Mean(abs(ref)) = {torch.mean(torch.abs(ref))}") + print(f"Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}") + print(f"Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}") + print(f"Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}") + print(f"allclose(fast, test) = {torch.allclose(fast, test)}") + + with torch.no_grad(): + # Also check that past_kv is handled properly + hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device) + part_len = seqlen // 4 + assert part_len * 4 == seqlen + mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device) + mask[0, -2:] = False + lmask = _prepare_decoder_attention_mask( + model, mask, hidden.shape[:2], hidden, 0 + ) + oneshot, _, _ = forward( + attn, hidden, attention_mask=lmask, position_ids=position_ids + ) + parts = [] + past_kv, past_kv_len = None, 0 + for i in range(4): + start = part_len * i + end = start + part_len + hidden_part = hidden[:, start:end, ...] 
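+        # Rebuild the padding mask for this chunk only; past_kv_len lets the
+        # patched helper account for the key/value positions already cached.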
+ lmask = _prepare_decoder_attention_mask( + model, + mask[:, start:end], + hidden_part.shape[:2], + hidden_part, + past_kv_len, + ) + part, _, past_kv = forward( + attn, + hidden_part.clone(), + attention_mask=lmask, + position_ids=position_ids[:, start:end], + past_key_value=past_kv, + use_cache=True, + ) + parts.append(part) + past_kv_len = past_kv[0].shape[2] + + print( + f"allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}" + ) + print( + f"allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}" + ) + + +if __name__ == "__main__": + test() diff --git a/fastchat/train/llama_flash_attn_monkey_patch.py b/fastchat/train/llama_flash_attn_monkey_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..b64aa8181726c26c9b3da355e17a6afb163f7796 --- /dev/null +++ b/fastchat/train/llama_flash_attn_monkey_patch.py @@ -0,0 +1,107 @@ +from typing import Optional, Tuple +import warnings + +import torch +from torch import nn +import transformers +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb + +from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func +from flash_attn.bert_padding import unpad_input, pad_input + + +def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + warnings.warn( + "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) # shape: (b, num_heads, s, head_dim) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + if past_key_value is not None: + # reuse k, v + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # Transform the data into the format required by flash attention + qkv = torch.stack([query_states, key_states, value_states], dim=2) + qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] + key_padding_mask = attention_mask + + if key_padding_mask is None: + qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) + cu_q_lens = torch.arange( + 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device + ) + max_s = q_len + output = flash_attn_varlen_qkvpacked_func( + qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True + ) + output = output.view(bsz, q_len, -1) + else: + qkv = qkv.reshape(bsz, q_len, -1) + qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) + qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + output_unpad = flash_attn_varlen_qkvpacked_func( + qkv, 
cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True + ) + output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) + output = pad_input(output_unpad, indices, bsz, q_len) + + return self.o_proj(output), None, past_key_value + + +# Disable the transformation of the attention mask in LlamaModel as the flash attention +# requires the attention mask to be the same as the key_padding_mask +def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length +): + # [bsz, seq_len] + return attention_mask + + +def replace_llama_attn_with_flash_attn(): + cuda_major, cuda_minor = torch.cuda.get_device_capability() + if cuda_major < 8: + warnings.warn( + "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." + "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" + ) + transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( + _prepare_decoder_attention_mask + ) + transformers.models.llama.modeling_llama.LlamaAttention.forward = forward diff --git a/fastchat/train/llama_xformers_attn_monkey_patch.py b/fastchat/train/llama_xformers_attn_monkey_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..f8351e41ccd4a64dca237bd8f8be0702b23989dc --- /dev/null +++ b/fastchat/train/llama_xformers_attn_monkey_patch.py @@ -0,0 +1,129 @@ +""" +Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments +""" + +import logging +import math +from typing import Optional, Tuple + +import torch +import transformers.models.llama.modeling_llama +from torch import nn + +try: + import xformers.ops +except ImportError: + logging.error("xformers not found! 
Please install it before trying to use it.") + + +def replace_llama_attn_with_xformers_attn(): + transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward + + +def xformers_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # pylint: disable=duplicate-code + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + ( + query_states, + key_states, + ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # We only apply xformers optimizations if we don't need to output the whole attention matrix + if not output_attentions: + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. + # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. 
+ if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: + # input and output should be of form (bsz, q_len, num_heads, head_dim) + attn_output = xformers.ops.memory_efficient_attention( + query_states, key_states, value_states, attn_bias=None + ) + else: + # input and output should be of form (bsz, q_len, num_heads, head_dim) + attn_output = xformers.ops.memory_efficient_attention( + query_states, + key_states, + value_states, + attn_bias=xformers.ops.LowerTriangularMask(), + ) + attn_weights = None + else: + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) + ) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value diff --git a/fastchat/train/train.py b/fastchat/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..8e8198e3861856b1f5f7020052db64b2dae29396 --- /dev/null +++ b/fastchat/train/train.py @@ -0,0 +1,318 @@ +# This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright: +# +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
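+#
+# Example launch (illustrative only; model, data, and hyperparameters are placeholders):
+#
+#     torchrun --nproc_per_node=4 fastchat/train/train.py \
+#         --model_name_or_path lmsys/vicuna-7b-v1.5 \
+#         --data_path data/dummy_conversation.json \
+#         --output_dir output_vicuna \
+#         --model_max_length 2048 \
+#         --per_device_train_batch_size 2 \
+#         --num_train_epochs 3 \
+#         --bf16 True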
+ +from dataclasses import dataclass, field +import json +import math +import pathlib +from typing import Dict, Optional, Sequence + +import numpy as np +import torch +from torch.utils.data import Dataset +import transformers +from transformers import Trainer +from transformers.trainer_pt_utils import LabelSmoother + +from fastchat.conversation import SeparatorStyle +from fastchat.model.model_adapter import get_conversation_template + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + trust_remote_code: bool = field( + default=False, + metadata={ + "help": "Whether or not to allow for custom models defined on the Hub in their own modeling files" + }, + ) + padding_side: str = field( + default="right", metadata={"help": "The padding side in tokenizer"} + ) + + +@dataclass +class DataArguments: + data_path: str = field( + default=None, metadata={"help": "Path to the training data."} + ) + eval_data_path: str = field( + default=None, metadata={"help": "Path to the evaluation data."} + ) + lazy_preprocess: bool = False + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=512, + metadata={ + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + + +local_rank = None + + +def rank0_print(*args): + if local_rank == 0: + print(*args) + + +def trainer_save_model_safe(trainer: transformers.Trainer): + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp import StateDictType, FullStateDictConfig + + save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type( + trainer.model, StateDictType.FULL_STATE_DICT, save_policy + ): + trainer.save_model() + + +def preprocess( + sources, + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + conv = get_conversation_template("vicuna") + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + targets = input_ids.clone() + + assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO + + # Mask targets. Only compute loss on the assistant outputs. + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + turns = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_TOKEN_ID + for i, turn in enumerate(turns): + if turn == "": + break + turn_len = len(tokenizer(turn).input_ids) + + parts = turn.split(sep) + if len(parts) != 2: + break + parts[0] += sep + # "-2" is hardcoded for the Llama tokenizer to make the offset correct. 
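+            # (This appears to compensate for the BOS token and the different
+            # encoding of the leading piece when the sub-string is tokenized in
+            # isolation rather than inside the full conversation.)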
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + if i != 0 and not tokenizer.legacy: + # The legacy and non-legacy modes handle special tokens differently + instruction_len -= 1 + + # Ignore the user instructions + target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID + cur_len += turn_len + + if i != 0 and not tokenizer.legacy: + # The legacy and non-legacy modes handle special tokens differently + cur_len -= 1 + + target[cur_len:] = IGNORE_TOKEN_ID + + if False: # Inspect and check the correctness of masking + z = target.clone() + z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) + rank0_print(tokenizer.decode(z)) + exit() + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + rank0_print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" #turn = {len(turns) - 1}. (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + ) + + +class SupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer): + super(SupervisedDataset, self).__init__() + + rank0_print("Formatting inputs...") + sources = [example["conversations"] for example in raw_data] + data_dict = preprocess(sources, tokenizer) + + self.input_ids = data_dict["input_ids"] + self.labels = data_dict["labels"] + self.attention_mask = data_dict["attention_mask"] + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + return dict( + input_ids=self.input_ids[i], + labels=self.labels[i], + attention_mask=self.attention_mask[i], + ) + + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer): + super(LazySupervisedDataset, self).__init__() + self.tokenizer = tokenizer + + rank0_print("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.raw_data = raw_data + self.cached_data_dict = {} + + def __len__(self): + return len(self.raw_data) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + if i in self.cached_data_dict: + return self.cached_data_dict[i] + + ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer) + ret = dict( + input_ids=ret["input_ids"][0], + labels=ret["labels"][0], + attention_mask=ret["attention_mask"][0], + ) + self.cached_data_dict[i] = ret + + return ret + + +def make_supervised_data_module( + tokenizer: transformers.PreTrainedTokenizer, data_args +) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + dataset_cls = ( + LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset + ) + rank0_print("Loading data...") + + train_json = json.load(open(data_args.data_path, "r")) + train_dataset = dataset_cls(train_json, tokenizer=tokenizer) + + if data_args.eval_data_path: + eval_json = json.load(open(data_args.eval_data_path, "r")) + eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer) + else: + eval_dataset = None + + return dict(train_dataset=train_dataset, eval_dataset=eval_dataset) + + +def train(): + global local_rank + + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + local_rank = training_args.local_rank + + # Set RoPE scaling factor + config = transformers.AutoConfig.from_pretrained( + 
model_args.model_name_or_path,
+        cache_dir=training_args.cache_dir,
+        trust_remote_code=model_args.trust_remote_code,
+    )
+    orig_ctx_len = getattr(config, "max_position_embeddings", None)
+    if orig_ctx_len and training_args.model_max_length > orig_ctx_len:
+        scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
+        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
+    config.use_cache = False
+
+    # Load model and tokenizer
+    model = transformers.AutoModelForCausalLM.from_pretrained(
+        model_args.model_name_or_path,
+        config=config,
+        cache_dir=training_args.cache_dir,
+        trust_remote_code=model_args.trust_remote_code,
+    )
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path,
+        cache_dir=training_args.cache_dir,
+        model_max_length=training_args.model_max_length,
+        padding_side=model_args.padding_side,
+        use_fast=False,
+        trust_remote_code=model_args.trust_remote_code,
+    )
+
+    if tokenizer.pad_token != tokenizer.unk_token:
+        tokenizer.pad_token = tokenizer.unk_token
+
+    # Load data
+    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
+
+    # Start the trainer
+    trainer = Trainer(
+        model=model, tokenizer=tokenizer, args=training_args, **data_module
+    )
+    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+        trainer.train(resume_from_checkpoint=True)
+    else:
+        trainer.train()
+
+    # Save model
+    model.config.use_cache = True
+    trainer.save_state()
+    if trainer.is_deepspeed_enabled:
+        trainer.save_model()
+    else:
+        trainer_save_model_safe(trainer)
+
+
+if __name__ == "__main__":
+    train()
diff --git a/fastchat/train/train_baichuan.py b/fastchat/train/train_baichuan.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6b19b486149c6d03859ff4750b07819a950f4b6
--- /dev/null
+++ b/fastchat/train/train_baichuan.py
@@ -0,0 +1,333 @@
+# This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright:
+#
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
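+
+# Expected training record layout, an illustrative sketch inferred from the data
+# loading code below rather than an official schema: the script reads the
+# "conversations" list, each message's "from"/"value" fields, and an optional
+# "system" prompt.
+#
+#   {
+#       "system": "<optional system prompt>",
+#       "conversations": [
+#           {"from": "human", "value": "<user message>"},
+#           {"from": "gpt", "value": "<assistant reply>"}
+#       ]
+#   }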
+ +from dataclasses import dataclass, field +import json +import math +import jsonlines +import pathlib +from multiprocessing import Pool +from typing import Dict, Optional, Sequence + +import numpy as np +import torch +from torch.utils.data import Dataset +import transformers +from transformers import Trainer +from transformers.trainer_pt_utils import LabelSmoother + +from fastchat.conversation import SeparatorStyle +from fastchat.model.model_adapter import get_conversation_template + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + + +@dataclass +class DataArguments: + data_path: str = field( + default=None, metadata={"help": "Path to the training data."} + ) + lazy_preprocess: bool = False + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=512, + metadata={ + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + + +local_rank = None + + +def rank0_print(*args): + if local_rank == 0: + print(*args) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): + """Collects the state dict and dump to disk.""" + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def apply_prompt_template(sources, systems=None): + conv = get_conversation_template("vicuna") + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + if systems and systems[i]: + conv.set_system_message(systems[i]) + prompt = conv.get_prompt() + conversations.append(prompt) + return conversations, conv + + +def tokenize_conversations(conversations, tokenizer): + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + targets = input_ids.clone() + return input_ids, targets + + +def mask_targets(conversations, targets, tokenizer, conv): + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + turns = conversation.split(conv.sep2) + cur_len = 0 + target[:cur_len] = IGNORE_TOKEN_ID + for i, turn in enumerate(turns): + if turn == "": + break + turn_len = len(tokenizer(turn + conv.sep2).input_ids) + + parts = turn.split(sep) + if len(parts) != 2: + break + parts[0] += sep + instruction_len = len(tokenizer(parts[0]).input_ids) - 1 + + target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID + cur_len += turn_len + + target[cur_len:] = IGNORE_TOKEN_ID + + if False: # Inspect and check the correctness of masking + z = target.clone() + z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) + rank0_print(tokenizer.decode(z)) + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + rank0_print( + f"WARNING: tokenization mismatch: {cur_len} vs. 
{total_len}." + f" (ignored)" + ) + return targets + + +def preprocess(sources, tokenizer: transformers.PreTrainedTokenizer, **kwargs) -> Dict: + systems = None if not kwargs else kwargs.get("systems", None) + + # If the data volume is small, process it directly in the main thread + if len(sources) <= 1000: + conversations, conv = apply_prompt_template(sources, systems) + input_ids, targets = tokenize_conversations(conversations, tokenizer) + targets = mask_targets(conversations, targets, tokenizer, conv) + else: # If the data volume is large, use multithreading for processing + with Pool() as p: + conversations, conv = p.apply_async( + apply_prompt_template, (sources, systems) + ).get() + input_ids, targets = p.apply_async( + tokenize_conversations, (conversations, tokenizer) + ).get() + targets = p.apply_async( + mask_targets, (conversations, targets, tokenizer, conv) + ).get() + p.close() + p.join() + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + ) + + +class SupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer): + super(SupervisedDataset, self).__init__() + + rank0_print("Formatting inputs...") + systems = [example.get("system", "") for example in raw_data] + sources = [example["conversations"] for example in raw_data] + + data_dict = preprocess(sources, tokenizer, systems=systems) + + self.input_ids = data_dict["input_ids"] + self.labels = data_dict["labels"] + self.attention_mask = data_dict["attention_mask"] + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + return dict( + input_ids=self.input_ids[i], + labels=self.labels[i], + attention_mask=self.attention_mask[i], + ) + + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer): + super(LazySupervisedDataset, self).__init__() + self.tokenizer = tokenizer + + rank0_print("Formatting inputs...Skip in lazy mode") + self.raw_data = raw_data + self.cached_data_dict = {} + + def __len__(self): + return len(self.raw_data) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + if i in self.cached_data_dict: + return self.cached_data_dict[i] + + ret = preprocess( + [self.raw_data[i]["conversations"]], + self.tokenizer, + systems=[self.raw_data[i].get("system", "")], + ) + ret = dict( + input_ids=ret["input_ids"][0], + labels=ret["labels"][0], + attention_mask=ret["attention_mask"][0], + ) + self.cached_data_dict[i] = ret + + return ret + + +def make_supervised_data_module( + tokenizer: transformers.PreTrainedTokenizer, data_args, train_ratio=0.98 +) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + train_ratio = min(train_ratio, 1.0) + dataset_cls = ( + LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset + ) + rank0_print("Loading data...") + data_path = data_args.data_path + if data_path.endswith(".json"): + raw_data = json.load(open(data_path, "r")) + elif data_path.endswith(".jsonl"): + with jsonlines.open(data_path, mode="r") as reader: + raw_data = [item for item in reader] + + # Split train/test + np.random.seed(0) + perm = np.random.permutation(len(raw_data)) + split = int(len(perm) * train_ratio) + train_indices = perm[:split] + if train_ratio < 1: + eval_indices = perm[split:] + else: + # if train_ratio==1, we use 5% of data as eval data, make sure trainer will 
not throw an error when the eval set is empty
+        eval_indices = perm[-int(len(perm) * 0.05) :]
+    train_raw_data = [raw_data[i] for i in train_indices]
+    eval_raw_data = [raw_data[i] for i in eval_indices]
+    rank0_print(f"#train {len(train_raw_data)}, #eval {len(eval_raw_data)}")
+
+    train_dataset = dataset_cls(train_raw_data, tokenizer=tokenizer)
+    eval_dataset = dataset_cls(eval_raw_data, tokenizer=tokenizer)
+    return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)
+
+
+def train():
+    global local_rank
+
+    parser = transformers.HfArgumentParser(
+        (ModelArguments, DataArguments, TrainingArguments)
+    )
+    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+    local_rank = training_args.local_rank
+    config = transformers.AutoConfig.from_pretrained(
+        model_args.model_name_or_path,
+        trust_remote_code=True,
+        cache_dir=training_args.cache_dir,
+    )
+    # Set RoPE scaling factor
+    orig_ctx_len = getattr(config, "max_position_embeddings", None)
+    if orig_ctx_len and training_args.model_max_length > orig_ctx_len:
+        scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
+        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
+    config.use_cache = False
+    model = transformers.AutoModelForCausalLM.from_pretrained(
+        model_args.model_name_or_path,
+        config=config,
+        trust_remote_code=True,
+        cache_dir=training_args.cache_dir,
+    )
+    # Tie the weights
+    model.tie_weights()
+
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path,
+        config=config,
+        trust_remote_code=True,
+        cache_dir=training_args.cache_dir,
+        model_max_length=training_args.model_max_length,
+        padding_side="right",
+        use_fast=False,
+    )
+    # NOTE: if a token id exceeds the vocab size, training will fail, so we add the
+    # special-token config and resize the embedding matrix to match the tokenizer.
+    tokenizer.pad_token = tokenizer.unk_token
+    print(f"tokens len: {len(tokenizer)}")
+    model.resize_token_embeddings(len(tokenizer))
+
+    data_module = make_supervised_data_module(
+        tokenizer=tokenizer, train_ratio=0.98, data_args=data_args
+    )
+    trainer = Trainer(
+        model=model, tokenizer=tokenizer, args=training_args, **data_module
+    )
+
+    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+        trainer.train(resume_from_checkpoint=True)
+    else:
+        trainer.train()
+    trainer.save_state()
+    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
+
+
+if __name__ == "__main__":
+    train()
diff --git a/fastchat/train/train_flant5.py b/fastchat/train/train_flant5.py
new file mode 100644
index 0000000000000000000000000000000000000000..688c2f4fa33ec50b5daab43b62e984b2aced1c68
--- /dev/null
+++ b/fastchat/train/train_flant5.py
@@ -0,0 +1,436 @@
+# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
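+
+# Rough launch sketch; the model name and paths are placeholders, and the flags map
+# onto the ModelArguments / DataArguments / TrainingArguments dataclasses defined
+# below. Preprocessed conversations are cached at --preprocessed_path by rank 0 and
+# reused on later runs.
+#
+#   torchrun --nproc_per_node=4 fastchat/train/train_flant5.py \
+#       --model_name_or_path <t5-family model> \
+#       --data_path <conversations.json> \
+#       --preprocessed_path <cache.json> \
+#       --output_dir <output dir> \
+#       --model_max_length 2048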
+ +from collections import defaultdict +import copy +import os +from dataclasses import dataclass, field +import random +import json +import logging +import pathlib +from typing import Dict, Optional, Sequence + +import torch +import torch.distributed as dist + +import transformers +from torch.utils.data import Dataset +from transformers import Trainer, AddedToken + +from fastchat.model.model_adapter import get_conversation_template + +default_conversation = get_conversation_template("t5") + +# TODO: import and use code from ../data/dataset.py + +IGNORE_INDEX = -100 +DEFAULT_PAD_TOKEN = "[PAD]" +DEFAULT_EOS_TOKEN = "" +DEFAULT_BOS_TOKEN = "" +DEFAULT_UNK_TOKEN = "" + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + + +@dataclass +class DataArguments: + data_path: str = field( + default=None, metadata={"help": "Path to the training data."} + ) + lazy_preprocess: bool = False + num_data: int = -1 + preprocessed_path: str = field( + default=None, metadata={"help": "Path to the preprocessed training data."} + ) + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=2048, + metadata={ + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): + """Collects the state dict and dump to disk.""" + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def smart_tokenizer_and_embedding_resize( + special_tokens_dict: Dict, + other_tokens, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
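+
+    Newly added tokens are initialized to the mean of the existing input and output
+    embeddings (see below) instead of being left at their random initialization,
+    which is a common heuristic for keeping early fine-tuning stable.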
+ """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + for new_token in other_tokens: + num_new_tokens += tokenizer.add_tokens(AddedToken(new_token, normalized=False)) + + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True + ) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True + ) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +def _tokenize_fn( + strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer +) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ) + for text in strings + ] + input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() + for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + +def _form_qa( + q_list, + a_list, + tokenized_conversation, + tokenized_lens, + speakers, + header_len, + max_length, + eos_id, +): + cur_idx = header_len + conv_len = len(tokenized_conversation) + + for tokenized_len, speaker in zip(tokenized_lens, speakers): + if cur_idx >= conv_len: + break + if speaker == "gpt": + # truncate answer if it is too long + content_a = None + if tokenized_len > max_length: + content_a = tokenized_conversation[cur_idx : cur_idx + max_length] + else: + content_a = tokenized_conversation[cur_idx : cur_idx + tokenized_len] + content_a.append(eos_id) + a_list.append(content_a) + content_q = None + if cur_idx >= max_length: + content_q = tokenized_conversation[cur_idx - max_length : cur_idx] + else: + content_q = tokenized_conversation[:cur_idx] + content_q.append(eos_id) + q_list.append(content_q) + # asser the last token is actually a EOS for an answer + assert a_list[-1][-1] == eos_id, "Last Token is not EOS!" 
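+        # Advance the cursor past this turn regardless of speaker, so the next
+        # question/answer slice starts after it.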
+ cur_idx += tokenized_len + + +def _add_speaker_and_signal(header, source, get_conversation=True): + """Add speaker and start/end signal on each round.""" + BEGIN_SIGNAL = "### " + END_SIGNAL = "\n" + conversation = header + + unknown_role = "unknown" # use default unknown role + roles = { + "human": default_conversation.roles[0], # human role + "gpt": default_conversation.roles[1], # gpt role + } + + for i in range(len(source)): + sentence = source[i] + sentence_from = sentence["from"].lower() + + # TODO(Dacheng): verify this is a good way to split sentences + if sentence_from == "human": + # if this is not the last sentence + if i != len(source) - 1: + next_sentence = source[i + 1] + sentence["value"] = ( + BEGIN_SIGNAL + + roles.get(sentence_from, unknown_role) + + ": " + + sentence["value"] + + END_SIGNAL + + BEGIN_SIGNAL + + roles.get(next_sentence["from"].lower(), unknown_role) + + ": " + ) + else: + # if human is the last speaker, it does not contribute to an answer + pass + else: + sentence["value"] = sentence["value"] + END_SIGNAL + if get_conversation: + conversation += sentence["value"] + + return conversation + + +def preprocess( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + """ + Given a list of sources, each is a conversation list. This transform: + 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; + 2. Concatenate conversations together; + 3. Tokenize the concatenated conversation; + 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. + """ + # add end signal and concatenate together + conversations = [] + header = f"{default_conversation.system_message}\n\n" + for source in sources: + conversation = _add_speaker_and_signal(header, source, tokenizer) + conversations.append(conversation) + # TODO(Dacheng): This is related to whether the dataset has been truncated.. 
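+    # (Each assistant turn is later expanded by _form_qa into one (question, answer)
+    # pair: the question is the tokenized history up to that turn, left-truncated to
+    # model_max_length, and the answer is the gpt reply followed by EOS.)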
+ # Assume we get long conversations, don't pad, don't return tensor + tokenized_conversations = tokenizer(conversations, max_length=None)["input_ids"] + q_list = [] + a_list = [] + # count for EOS length + header_len = _tokenize_fn([header], tokenizer)["input_ids_lens"][0] - 1 + from tqdm import tqdm + + for tokenized_conversation, source in tqdm(zip(tokenized_conversations, sources)): + tokenized_sentence = _tokenize_fn([s["value"] for s in source], tokenizer) + tokenized_lens = tokenized_sentence["input_ids_lens"] + tokenized_lens = [l - 1 for l in tokenized_lens] + speakers = [sentence["from"] for sentence in source] + ids = tokenized_sentence["input_ids"] + _form_qa( + q_list, + a_list, + tokenized_conversation, + tokenized_lens, + speakers, + header_len, + tokenizer.model_max_length, + tokenizer.eos_token_id, + ) + return dict(input_ids=q_list, labels=a_list) + + +class SupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__( + self, + data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + preprocessed_path, + num_data, + ): + super(SupervisedDataset, self).__init__() + + # save to file + # Make sure only the first process is processing the dataset + if dist.get_rank() != 0: + dist.barrier() + self.preprocessed_path = preprocessed_path + if os.path.exists(self.preprocessed_path): + logging.warning("loading from preprocessed data") + with open(self.preprocessed_path, "r") as f: + data_dict = json.load(f) + if dist.get_rank() == 0: + dist.barrier() + else: + if not os.path.exists("preprocessed_data"): + os.mkdir("preprocessed_data") + assert dist.get_rank() == 0, "Only the first process should process" + logging.warning("Loading data...") + list_data_dict = json.load(open(data_path, "r")) + + logging.warning("Formatting inputs...") + sources = [] + + sources = [example["conversations"] for example in list_data_dict] + + data_dict = preprocess(sources, tokenizer) + json_data_dict = json.dumps(data_dict) + + # Remember to close file to avoid concurrent r/w + with open(self.preprocessed_path, "w") as f: + f.write(json_data_dict) + + # Release barrier + dist.barrier() + + if num_data != -1: + data_dict["input_ids"] = data_dict["input_ids"][:num_data] + data_dict["labels"] = data_dict["labels"][:num_data] + + # Shuffle data to see more conversations, if only train on partial data + temp = list(zip(data_dict["input_ids"], data_dict["labels"])) + random.shuffle(temp) + res1, res2 = zip(*temp) + data_dict["input_ids"], data_dict["labels"] = list(res1), list(res2) + + # Dacheng: Get rid of short QA pair + self.input_ids = copy.deepcopy(data_dict["input_ids"]) + self.labels = copy.deepcopy(data_dict["labels"]) + length_arr = defaultdict(int) + for idx, (input, label) in enumerate( + zip(data_dict["input_ids"], data_dict["labels"]) + ): + length_arr[str(len(label) // 100)] += 1 + if len(input) <= 5: + del_idx = self.input_ids.index(input) + self.input_ids.pop(del_idx) + self.labels.pop(del_idx) + if len(label) <= 5: + del_idx = self.labels.index(label) + self.input_ids.pop(del_idx) + self.labels.pop(del_idx) + + for input, label in zip(self.input_ids, self.labels): + assert len(input) >= 5 + assert len(label) >= 5 + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + return dict(input_ids=self.input_ids[i], labels=self.labels[i]) + + +@dataclass +class DataCollatorForSupervisedDataset(object): + """Collate examples for supervised fine-tuning.""" + + tokenizer: transformers.PreTrainedTokenizer + + 
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple( + [ + torch.as_tensor(instance[key], dtype=torch.int64) + for instance in instances + ] + for key in ("input_ids", "labels") + ) + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id + ) + labels = torch.nn.utils.rnn.pad_sequence( + labels, batch_first=True, padding_value=IGNORE_INDEX + ) + ret = dict( + input_ids=input_ids, + labels=labels, + attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + ) + torch.set_printoptions(profile="full") + return ret + + +def make_supervised_data_module( + tokenizer: transformers.PreTrainedTokenizer, data_args +) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + dataset_cls = SupervisedDataset + train_dataset = dataset_cls( + tokenizer=tokenizer, + data_path=data_args.data_path, + preprocessed_path=data_args.preprocessed_path, + num_data=data_args.num_data, + ) + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + return dict( + train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator + ) + + +def train(): + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + model = transformers.AutoModelForSeq2SeqLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + ) + # Dacheng: Note we can only use T5Tokenizer, otherwise it will prepend + # a space before special tokens. + tokenizer = transformers.T5Tokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + + smart_tokenizer_and_embedding_resize( + special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), + other_tokens=["<", "{", "\n", "}", "`", " ", "\\", "^", "\t"], + tokenizer=tokenizer, + model=model, + ) + + data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) + trainer = Trainer( + model=model, tokenizer=tokenizer, args=training_args, **data_module + ) + + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) + + +if __name__ == "__main__": + train() diff --git a/fastchat/train/train_lora.py b/fastchat/train/train_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..9ecb47c29fbb21f6e57d9de1cba70002a886d152 --- /dev/null +++ b/fastchat/train/train_lora.py @@ -0,0 +1,222 @@ +# Usage: deepspeed train_lora.py --deepspeed <$PATH_TO_DEEPSPEED_CONFIG> + +# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +import logging +import pathlib +import typing +import os + +from deepspeed import zero +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training +import transformers +from transformers import Trainer, BitsAndBytesConfig, deepspeed +import torch + +from fastchat.train.train import ( + DataArguments, + ModelArguments, + make_supervised_data_module, +) + +from fastchat.train.llama_flash_attn_monkey_patch import ( + replace_llama_attn_with_flash_attn, +) + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: typing.Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=512, + metadata={ + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + flash_attn: bool = False + + +@dataclass +class LoraArguments: + lora_r: int = 8 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_target_modules: typing.List[str] = field( + default_factory=lambda: ["q_proj", "v_proj"] + ) + lora_weight_path: str = "" + lora_bias: str = "none" + q_lora: bool = False + + +def maybe_zero_3(param): + if hasattr(param, "ds_id"): + assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v) for k, v in to_return.items()} + return to_return + + +def train(): + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments, LoraArguments) + ) + ( + model_args, + data_args, + training_args, + lora_args, + ) = parser.parse_args_into_dataclasses() + + if training_args.flash_attn: + replace_llama_attn_with_flash_attn() + + device_map = None + world_size = int(os.environ.get("WORLD_SIZE", 1)) + ddp = world_size != 1 + if lora_args.q_lora: + device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None + if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled(): + logging.warning( + "FSDP and ZeRO3 are both currently incompatible with QLoRA." 
+ ) + + compute_dtype = ( + torch.float16 + if training_args.fp16 + else (torch.bfloat16 if training_args.bf16 else torch.float32) + ) + + model = transformers.AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + device_map=device_map, + quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=compute_dtype, + ) + if lora_args.q_lora + else None, + ) + lora_config = LoraConfig( + r=lora_args.lora_r, + lora_alpha=lora_args.lora_alpha, + target_modules=lora_args.lora_target_modules, + lora_dropout=lora_args.lora_dropout, + bias=lora_args.lora_bias, + task_type="CAUSAL_LM", + ) + + if lora_args.q_lora: + model = prepare_model_for_kbit_training( + model, use_gradient_checkpointing=training_args.gradient_checkpointing + ) + if not ddp and torch.cuda.device_count() > 1: + # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available + model.is_parallelizable = True + model.model_parallel = True + + model = get_peft_model(model, lora_config) + if training_args.flash_attn: + for name, module in model.named_modules(): + if "norm" in name: + module = module.to(compute_dtype) + if "lm_head" in name or "embed_tokens" in name: + if hasattr(module, "weight"): + module = module.to(compute_dtype) + if training_args.deepspeed is not None and training_args.local_rank == 0: + model.print_trainable_parameters() + + if training_args.gradient_checkpointing: + model.enable_input_require_grads() + + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + tokenizer.pad_token = tokenizer.unk_token + + data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) + trainer = Trainer( + model=model, tokenizer=tokenizer, args=training_args, **data_module + ) + + model.config.use_cache = False + + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + + # check if zero3 mode enabled + if deepspeed.is_deepspeed_zero3_enabled(): + # use deepspeed engine internal function to gather state dict + # state_dict_zero3 contains whole parameters of base and lora adapters + # we will not extract lora parameters since peft save_pretrained will do that + # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/peft_model.py#L125 + # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/utils/save_and_load.py#L19 + state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict() + if training_args.local_rank == 0: + state_dict = state_dict_zero3 + else: + # in other mode we use original code from fastchat team, to make sure our change is minimum + state_dict = get_peft_state_maybe_zero_3( + model.named_parameters(), lora_args.lora_bias + ) + + if training_args.local_rank == 0: + model.save_pretrained(training_args.output_dir, state_dict=state_dict) + + +if __name__ == "__main__": + train() diff --git a/fastchat/train/train_lora_t5.py b/fastchat/train/train_lora_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..21abc92cb9e64482d4c3375e6321bd00641ac4f9 --- /dev/null +++ b/fastchat/train/train_lora_t5.py @@ -0,0 +1,226 @@ +# Adapted from 
tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +import copy +import os +from dataclasses import dataclass, field +import random +import json +import logging +import pathlib +from typing import Dict, Optional, Sequence, List + +import torch +import torch.distributed as dist + + +from deepspeed import zero +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType + +import transformers +from torch.utils.data import Dataset +from transformers import Trainer, AddedToken, BitsAndBytesConfig, deepspeed + +from fastchat.train.train_flant5 import ( + smart_tokenizer_and_embedding_resize, + make_supervised_data_module, +) + +from fastchat.train.train_lora import get_peft_state_maybe_zero_3 + +from fastchat.model.model_adapter import get_conversation_template + +default_conversation = get_conversation_template("t5") + +# TODO: import and use code from ../data/dataset.py + +IGNORE_INDEX = -100 +DEFAULT_PAD_TOKEN = "[PAD]" +DEFAULT_EOS_TOKEN = "" +DEFAULT_BOS_TOKEN = "" +DEFAULT_UNK_TOKEN = "" + + +@dataclass +class LoraArguments: + lora_r: int = 8 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_target_modules: List[str] = field(default_factory=lambda: ["q", "v"]) + lora_weight_path: str = "" + lora_bias: str = "none" + q_lora: bool = False + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + + +@dataclass +class DataArguments: + data_path: str = field( + default=None, metadata={"help": "Path to the training data."} + ) + lazy_preprocess: bool = False + num_data: int = -1 + preprocessed_path: str = field( + default=None, metadata={"help": "Path to the preprocessed training data."} + ) + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=2048, + metadata={ + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 
+ }, + ) + + +def safe_save_model_for_hf_trainer( + trainer: transformers.Trainer, output_dir: str, state_dict: dict +): + """Collects the state dict and dump to disk.""" + + if trainer.args.should_save: + cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def train(): + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments, LoraArguments) + ) + ( + model_args, + data_args, + training_args, + lora_args, + ) = parser.parse_args_into_dataclasses() + + device_map = None + world_size = int(os.environ.get("WORLD_SIZE", 1)) + ddp = world_size != 1 + if lora_args.q_lora: + device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None + if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled(): + logging.warning( + "FSDP and ZeRO3 are both currently incompatible with QLoRA." + ) + + compute_dtype = ( + torch.float16 + if training_args.fp16 + else (torch.bfloat16 if training_args.bf16 else torch.float32) + ) + + model = transformers.AutoModelForSeq2SeqLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + device_map=device_map, + quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=compute_dtype, + ) + if lora_args.q_lora + else None, + ) + + lora_config = LoraConfig( + r=lora_args.lora_r, + lora_alpha=lora_args.lora_alpha, + target_modules=lora_args.lora_target_modules, + lora_dropout=lora_args.lora_dropout, + bias=lora_args.lora_bias, + task_type=TaskType.SEQ_2_SEQ_LM, + ) + + if lora_args.q_lora: + model = prepare_model_for_kbit_training( + model, use_gradient_checkpointing=training_args.gradient_checkpointing + ) + if not ddp and torch.cuda.device_count() > 1: + # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available + model.is_parallelizable = True + model.model_parallel = True + + model = get_peft_model(model, lora_config) + if training_args.deepspeed is not None and training_args.local_rank == 0: + model.print_trainable_parameters() + + if training_args.gradient_checkpointing: + model.enable_input_require_grads() + + # Dacheng: Note we can only use T5Tokenizer, otherwise it will prepend + # a space before special tokens. 
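+    # (The smart_tokenizer_and_embedding_resize call below registers a pad token plus
+    # a few characters (braces, backticks, newline, tab, etc.) that appear to be
+    # missing from the T5 SentencePiece vocabulary, then resizes the embeddings to
+    # match.)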
+ tokenizer = transformers.T5Tokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + + smart_tokenizer_and_embedding_resize( + special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), + other_tokens=["<", "{", "\n", "}", "`", " ", "\\", "^", "\t"], + tokenizer=tokenizer, + model=model, + ) + + data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) + + trainer = Trainer( + model=model, tokenizer=tokenizer, args=training_args, **data_module + ) + + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + # check if zero3 mode enabled + if deepspeed.is_deepspeed_zero3_enabled(): + # use deepspeed engine internal function to gather state dict + # state_dict_zero3 contains whole parameters of base and lora adapters + # we will not extract lora parameters since peft save_pretrained will do that + # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/peft_model.py#L125 + # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/utils/save_and_load.py#L19 + state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict() + if training_args.local_rank == 0: + state_dict = state_dict_zero3 + else: + # in other mode we use original code from fastchat team, to make sure our change is minimum + state_dict = get_peft_state_maybe_zero_3( + model.named_parameters(), lora_args.lora_bias + ) + + if training_args.local_rank == 0: + safe_save_model_for_hf_trainer( + trainer=trainer, output_dir=training_args.output_dir, state_dict=state_dict + ) + + +if __name__ == "__main__": + train() diff --git a/fastchat/train/train_mem.py b/fastchat/train/train_mem.py new file mode 100644 index 0000000000000000000000000000000000000000..9ce4913aae3ef2080470161724a4f7127abb11f0 --- /dev/null +++ b/fastchat/train/train_mem.py @@ -0,0 +1,13 @@ +# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. + +# Need to call this before importing transformers. +from fastchat.train.llama2_flash_attn_monkey_patch import ( + replace_llama_attn_with_flash_attn, +) + +replace_llama_attn_with_flash_attn() + +from fastchat.train.train import train + +if __name__ == "__main__": + train() diff --git a/fastchat/train/train_with_template.py b/fastchat/train/train_with_template.py new file mode 100644 index 0000000000000000000000000000000000000000..e5c5f353de1b14a64150faa703165f52878d309d --- /dev/null +++ b/fastchat/train/train_with_template.py @@ -0,0 +1,400 @@ +# This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright: +# +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
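+
+# Unlike train.py, which hard-codes the "vicuna" template and its ADD_COLON_TWO
+# separator style, this script looks up the conversation template named by
+# --model_name_or_path and derives the user/assistant turn separators from the
+# template's separator style (see get_prompt_separator below), so the loss masking
+# also works for styles such as LLAMA2 and CHATML.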
+ +from dataclasses import dataclass, field +import json +import math +import jsonlines +import pathlib +from multiprocessing import Pool +from typing import Dict, Optional, Sequence + +import numpy as np +import torch +from torch.utils.data import Dataset +import transformers +from transformers import Trainer +from transformers.trainer_pt_utils import LabelSmoother + +from fastchat.conversation import SeparatorStyle +from fastchat.model.model_adapter import get_conversation_template + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + + +@dataclass +class DataArguments: + data_path: str = field( + default=None, metadata={"help": "Path to the training data."} + ) + lazy_preprocess: bool = False + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=512, + metadata={ + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + + +local_rank = None + + +def rank0_print(*args): + if local_rank == 0: + print(*args) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): + """Collects the state dict and dump to disk.""" + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def apply_prompt_template(sources, template_id, systems=None): + conv = get_conversation_template(template_id) + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + if systems and systems[i]: + conv.set_system_message(systems[i]) + prompt = conv.get_prompt() + conversations.append(prompt) + return conversations, conv + + +def tokenize_conversations(conversations, tokenizer): + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + targets = input_ids.clone() + return input_ids, targets + + +def get_prompt_separator(conv): + if conv.sep_style == SeparatorStyle.ADD_COLON_SINGLE: + user_turn_separator = conv.sep2 + assistant_turn_separator = conv.roles[1] + ": " + + elif conv.sep_style == SeparatorStyle.ADD_COLON_TWO: + user_turn_separator = conv.sep2 + assistant_turn_separator = conv.roles[1] + ": " + + elif conv.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: + if conv.sep2 is None: + user_turn_separator = conv.roles[0] + ": " + else: + user_turn_separator = conv.sep2 + + assistant_turn_separator = conv.roles[1] + ": " + + elif conv.sep_style == SeparatorStyle.LLAMA2: + user_turn_separator = conv.sep2 + assistant_turn_separator = conv.roles[1] + " " + + elif conv.sep_style == SeparatorStyle.CHATML: + if conv.sep2 is None: + user_turn_separator = conv.sep + "\n" + else: + user_turn_separator = conv.sep2 + "\n" + + assistant_turn_separator = conv.roles[1] + "\n" + + return user_turn_separator, assistant_turn_separator + + +def mask_targets(conversations, targets, tokenizer, conv): + for 
conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + if tokenizer.eos_token is None: + cur_len = 0 + elif tokenizer.eos_token is not None and target[0] != tokenizer.bos_token_id: + cur_len = 0 + elif tokenizer.eos_token is not None and target[0] == tokenizer.bos_token_id: + cur_len = 1 + + target[:cur_len] = IGNORE_TOKEN_ID + user_turn_separator, assistant_turn_separator = get_prompt_separator(conv) + turns = conversation.split(user_turn_separator) + for i, turn in enumerate(turns): + if ( + i < len(turns) - 1 and turn == "" + ): # Last turn is the user_turn_separator + break + + if i != 0: + turn = user_turn_separator + turn + + turn_len = len(tokenizer(turn, add_special_tokens=False).input_ids) + + if assistant_turn_separator in turn: + parts = turn.rsplit(assistant_turn_separator) + parts[0] += assistant_turn_separator + else: + parts = [turn] + + instruction_len = len( + tokenizer(parts[0], add_special_tokens=False).input_ids + ) + + target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID + cur_len += turn_len + + target[cur_len:] = IGNORE_TOKEN_ID + + if False: # Inspect and check the correctness of masking + z = target.clone() + z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) + rank0_print(tokenizer.decode(z)) + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + rank0_print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + return targets + + +def preprocess( + sources, tokenizer: transformers.PreTrainedTokenizer, template_id, **kwargs +) -> Dict: + systems = None if not kwargs else kwargs.get("systems", None) + + # If the data volume is small, process it directly in the main thread + if len(sources) <= 1000: + conversations, conv = apply_prompt_template(sources, template_id, systems) + input_ids, targets = tokenize_conversations(conversations, tokenizer) + targets = mask_targets(conversations, targets, tokenizer, conv) + else: # If the data volume is large, use multithreading for processing + with Pool() as p: + conversations, conv = p.apply_async( + apply_prompt_template, (sources, template_id, systems) + ).get() + input_ids, targets = p.apply_async( + tokenize_conversations, (conversations, tokenizer) + ).get() + targets = p.apply_async( + mask_targets, (conversations, targets, tokenizer, conv) + ).get() + p.close() + p.join() + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + ) + + +class SupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__( + self, raw_data, tokenizer: transformers.PreTrainedTokenizer, template_id + ): + super(SupervisedDataset, self).__init__() + + rank0_print("Formatting inputs...") + systems = [example.get("system", "") for example in raw_data] + sources = [example["conversations"] for example in raw_data] + + data_dict = preprocess(sources, tokenizer, template_id, systems=systems) + + self.input_ids = data_dict["input_ids"] + self.labels = data_dict["labels"] + self.attention_mask = data_dict["attention_mask"] + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + return dict( + input_ids=self.input_ids[i], + labels=self.labels[i], + attention_mask=self.attention_mask[i], + ) + + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__( + self, raw_data, tokenizer: 
transformers.PreTrainedTokenizer, template_id + ): + super(LazySupervisedDataset, self).__init__() + self.tokenizer = tokenizer + self.template_id = template_id + + rank0_print("Formatting inputs...Skip in lazy mode") + self.raw_data = raw_data + self.cached_data_dict = {} + + def __len__(self): + return len(self.raw_data) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + if i in self.cached_data_dict: + return self.cached_data_dict[i] + + ret = preprocess( + [self.raw_data[i]["conversations"]], + self.tokenizer, + self.template_id, + systems=[self.raw_data[i].get("system", "")], + ) + ret = dict( + input_ids=ret["input_ids"][0], + labels=ret["labels"][0], + attention_mask=ret["attention_mask"][0], + ) + self.cached_data_dict[i] = ret + + return ret + + +def make_supervised_data_module( + tokenizer: transformers.PreTrainedTokenizer, + data_args, + template_id, + train_ratio=0.98, +) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + train_ratio = min(train_ratio, 1.0) + dataset_cls = ( + LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset + ) + rank0_print("Loading data...") + data_path = data_args.data_path + if data_path.endswith(".json"): + raw_data = json.load(open(data_path, "r")) + elif data_path.endswith(".jsonl"): + with jsonlines.open(data_path, mode="r") as reader: + raw_data = [item for item in reader] + + # Split train/test + np.random.seed(0) + perm = np.random.permutation(len(raw_data)) + split = int(len(perm) * train_ratio) + train_indices = perm[:split] + if train_ratio < 1: + eval_indices = perm[split:] + else: + # if train_ratio==1, we use 5% of data as eval data, make sure trainer will not throw error when eval data is empty + eval_indices = perm[-int(len(perm) * 0.05) :] + train_raw_data = [raw_data[i] for i in train_indices] + eval_raw_data = [raw_data[i] for i in eval_indices] + rank0_print(f"#train {len(train_raw_data)}, #eval {len(eval_raw_data)}") + + train_dataset = dataset_cls( + train_raw_data, tokenizer=tokenizer, template_id=template_id + ) + eval_dataset = dataset_cls( + eval_raw_data, tokenizer=tokenizer, template_id=template_id + ) + return dict(train_dataset=train_dataset, eval_dataset=eval_dataset) + + +def train(): + global local_rank + + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + local_rank = training_args.local_rank + config = transformers.AutoConfig.from_pretrained( + model_args.model_name_or_path, + trust_remote_code=True, + cache_dir=training_args.cache_dir, + ) + # Set RoPE scaling factor + orig_ctx_len = getattr(config, "max_position_embeddings", None) + if orig_ctx_len and training_args.model_max_length > orig_ctx_len: + scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len)) + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + config.use_cache = False + model = transformers.AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + config=config, + trust_remote_code=True, + cache_dir=training_args.cache_dir, + ) + # Tie the weights + model.tie_weights() + + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + config=config, + trust_remote_code=True, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + # NOTE: if the token_id exceed the vocab_size will cause failing in training process! 
we need add special config and resize the embedding size! + tokenizer.pad_token = tokenizer.unk_token + tokenizer.pad_token_id = tokenizer.unk_token_id + print(f"tokens len: {len(tokenizer)}") + model.resize_token_embeddings(len(tokenizer)) + + template_id = model_args.model_name_or_path + data_module = make_supervised_data_module( + tokenizer=tokenizer, + template_id=template_id, + train_ratio=0.98, + data_args=data_args, + ) + trainer = Trainer( + model=model, tokenizer=tokenizer, args=training_args, **data_module + ) + + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) + + +if __name__ == "__main__": + train() diff --git a/fastchat/train/train_xformers.py b/fastchat/train/train_xformers.py new file mode 100644 index 0000000000000000000000000000000000000000..0eb2badd59140d72ff995ad4419fde2a2a697955 --- /dev/null +++ b/fastchat/train/train_xformers.py @@ -0,0 +1,13 @@ +# Make it more memory efficient by monkey patching the LLaMA model with xformers attention. + +# Need to call this before importing transformers. +from fastchat.train.llama_xformers_attn_monkey_patch import ( + replace_llama_attn_with_xformers_attn, +) + +replace_llama_attn_with_xformers_attn() + +from fastchat.train.train import train + +if __name__ == "__main__": + train() diff --git a/fastchat/train/train_yuan2.py b/fastchat/train/train_yuan2.py new file mode 100644 index 0000000000000000000000000000000000000000..6f3c09a14a044f7c085e149b5c6b81202c7d3b6a --- /dev/null +++ b/fastchat/train/train_yuan2.py @@ -0,0 +1,482 @@ +# This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright: +# +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
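+
+# Yuan2-oriented variant of train.py. DataArguments below adds three masking flags,
+# split_example_loss, efficient_loss and last_response_loss, which (as the flag
+# names and the preprocess() branches suggest) respectively expand each conversation
+# into one example per assistant response, keep the loss on every assistant response
+# while masking user turns, or keep the loss only on the final response.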
+ +from dataclasses import dataclass, field +import json +import math +import pathlib +from typing import Dict, Optional, Sequence + +import numpy as np +import torch +from torch.utils.data import Dataset +import transformers +from transformers import Trainer +from transformers.trainer_pt_utils import LabelSmoother + +from fastchat.conversation import SeparatorStyle +from fastchat.model.model_adapter import get_conversation_template + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + trust_remote_code: bool = field( + default=False, + metadata={ + "help": "Whether or not to allow for custom models defined on the Hub in their own modeling files" + }, + ) + padding_side: str = field( + default="right", metadata={"help": "The padding side in tokenizer"} + ) + + +@dataclass +class DataArguments: + data_path: str = field( + default=None, metadata={"help": "Path to the training data."} + ) + eval_data_path: str = field( + default=None, metadata={"help": "Path to the evaluation data."} + ) + lazy_preprocess: bool = False + last_response_loss: bool = False + split_example_loss: bool = False + efficient_loss: bool = False + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=512, + metadata={ + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + + +local_rank = None + + +def rank0_print(*args): + if local_rank == 0: + print(*args) + + +def trainer_save_model_safe(trainer: transformers.Trainer): + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp import StateDictType, FullStateDictConfig + + save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type( + trainer.model, StateDictType.FULL_STATE_DICT, save_policy + ): + trainer.save_model() + + +# add by wpf for yuan test +def right_replace(string, old, new, max=1): + return string[::-1].replace(old[::-1], new[::-1], max)[::-1] + + +def preprocess( + sources, + tokenizer: transformers.PreTrainedTokenizer, + data_args, +) -> Dict: + conv = get_conversation_template("yuan2") # wpf + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + if data_args.last_response_loss: + a = conversations[0].replace("", "") + a = right_replace(a, "", "") + # a=right_replace(a,"","\n",max=20) + conversations[0] = a + if data_args.split_example_loss: + a = conversations[0].replace("", "") + a = a.split("") + for i in range(int(len(a) / 2)): + if i == 0: + conversations[i] = "" + if i != 0: + conversations.append("") + for j in range(i * 2): + conversations[i] = conversations[i] + a[j] + "" + conversations[i] = ( + conversations[i] + a[i * 2] + "" + a[i * 2 + 1] + "" + ) + + if data_args.efficient_loss: + a = conversations[0].replace("", "") + conversations[0] = a + + print(conversations) + + # Tokenize conversations + input_ids = tokenizer( + 
conversations, + return_tensors="pt", + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + targets = input_ids.clone() + + # assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO #wpf + # Mask targets. Only compute loss on the assistant outputs. + # sep = conv.sep + conv.roles[1] + ": " #wpf + + if data_args.split_example_loss: + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + turns = conversation.split("") + cur_len = 1 + target[:cur_len] = IGNORE_TOKEN_ID + + for i, turn in enumerate(turns): + if turn == "": + break + if i == 0 or i == len(turns) - 1: + turn_len = len(tokenizer(turn).input_ids) + else: + turn_len = len(tokenizer(turn).input_ids) + 1 + # parts = turn.split(sep) + # if len(parts) != 2: + # break + # parts[0] += sep + # "-2" is hardcoded for the Llama tokenizer to make the offset correct. + instruction_len = 0 + if i == len(turns) - 1: + instruction_len = turn_len + target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID + cur_len += turn_len + + target[cur_len:] = IGNORE_TOKEN_ID + # print("cur_len: ", cur_len) + # print("total_len: ", total_len) + + if False: # Inspect and check the correctness of masking + z = target.clone() + z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) + rank0_print(tokenizer.decode(z)) + exit() + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + rank0_print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" #turn = {len(turns) - 1}. (ignored)" + ) + + if data_args.efficient_loss: + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + turns = conversation.split("") + cur_len = 1 + target[:cur_len] = IGNORE_TOKEN_ID + + for i, turn in enumerate(turns): + if turn == "": + break + if i == 0 or i == len(turns) - 1: + turn_len = len(tokenizer(turn).input_ids) + else: + turn_len = len(tokenizer(turn).input_ids) + 1 + # parts = turn.split(sep) + # if len(parts) != 2: + # break + # parts[0] += sep + # "-2" is hardcoded for the Llama tokenizer to make the offset correct. + instruction_len = 0 + if i % 2 == 0: + instruction_len = turn_len + + # if i != 0 and not tokenizer.legacy: + # # The legacy and non-legacy modes handle special tokens differently + # instruction_len -= 1 + + # Ignore the user instructions + target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID + cur_len += turn_len + + if i != 0 and not tokenizer.legacy: + # The legacy and non-legacy modes handle special tokens differently + cur_len -= 1 + target[cur_len:] = IGNORE_TOKEN_ID + # print("cur_len: ", cur_len) + # print("total_len: ", total_len) + + if False: # Inspect and check the correctness of masking + z = target.clone() + z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) + rank0_print(tokenizer.decode(z)) + exit() + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + rank0_print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" #turn = {len(turns) - 1}. 
(ignored)" + ) + if data_args.last_response_loss: + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + turns = conversation.split("") + cur_len = 1 + target[:cur_len] = IGNORE_TOKEN_ID + + for i, turn in enumerate(turns): + if turn == "": + break + if i == 0 or i == len(turns) - 1: + turn_len = len(tokenizer(turn).input_ids) + else: + turn_len = len(tokenizer(turn).input_ids) + 1 + # parts = turn.split(sep) + # if len(parts) != 2: + # break + # parts[0] += sep + # "-2" is hardcoded for the Llama tokenizer to make the offset correct. + instruction_len = 0 + if i == len(turns) - 1: + instruction_len = turn_len + + # if i != 0 and not tokenizer.legacy: + # # The legacy and non-legacy modes handle special tokens differently + # instruction_len -= 1 + + # Ignore the user instructions + target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID + cur_len += turn_len + + # if i != 0 and not tokenizer.legacy: + # # The legacy and non-legacy modes handle special tokens differently + # cur_len -= 1 + + target[cur_len:] = IGNORE_TOKEN_ID + # print("cur_len: ", cur_len) + # print("total_len: ", total_len) + + if False: # Inspect and check the correctness of masking + z = target.clone() + z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) + rank0_print(tokenizer.decode(z)) + exit() + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + rank0_print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" #turn = {len(turns) - 1}. (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + ) + + +class SupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__( + self, raw_data, data_args, tokenizer: transformers.PreTrainedTokenizer + ): + super(SupervisedDataset, self).__init__() + + rank0_print("Formatting inputs...") + sources = [example["conversations"] for example in raw_data] + data_dict = preprocess(sources, tokenizer, data_args) + + self.input_ids = data_dict["input_ids"] + self.labels = data_dict["labels"] + self.attention_mask = data_dict["attention_mask"] + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + return dict( + input_ids=self.input_ids[i], + labels=self.labels[i], + attention_mask=self.attention_mask[i], + ) + + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__( + self, raw_data, data_args, tokenizer: transformers.PreTrainedTokenizer + ): + super(LazySupervisedDataset, self).__init__() + self.tokenizer = tokenizer + + rank0_print("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.raw_data = raw_data + self.data_args = data_args + self.cached_data_dict = {} + + def __len__(self): + return len(self.raw_data) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + if i in self.cached_data_dict: + return self.cached_data_dict[i] + + ret = preprocess( + [self.raw_data[i]["conversations"]], self.tokenizer, self.data_args + ) + ret = dict( + input_ids=ret["input_ids"][0], + labels=ret["labels"][0], + attention_mask=ret["attention_mask"][0], + ) + self.cached_data_dict[i] = ret + + return ret + + +def make_supervised_data_module( + tokenizer: transformers.PreTrainedTokenizer, data_args +) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + dataset_cls = ( + LazySupervisedDataset if 
data_args.lazy_preprocess else SupervisedDataset + ) + rank0_print("Loading data...") + + train_json = json.load(open(data_args.data_path, "r")) + train_dataset = dataset_cls(train_json, data_args, tokenizer=tokenizer) + + if data_args.eval_data_path: + eval_json = json.load(open(data_args.eval_data_path, "r")) + eval_dataset = dataset_cls(eval_json, data_args, tokenizer=tokenizer) + else: + eval_dataset = None + + return dict(train_dataset=train_dataset, eval_dataset=eval_dataset) + + +def train(): + global local_rank + + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + local_rank = training_args.local_rank + + # Set RoPE scaling factor + config = transformers.AutoConfig.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + trust_remote_code=model_args.trust_remote_code, + ) + orig_ctx_len = getattr(config, "max_position_embeddings", None) + if orig_ctx_len and training_args.model_max_length > orig_ctx_len: + scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len)) + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + config.use_cache = False + + # Load model and tokenizer + model = transformers.AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + config=config, + cache_dir=training_args.cache_dir, + trust_remote_code=model_args.trust_remote_code, + ) + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side=model_args.padding_side, + use_fast=False, + trust_remote_code=model_args.trust_remote_code, + ) + + if tokenizer.pad_token != tokenizer.unk_token: + tokenizer.pad_token = tokenizer.unk_token + tokenizer.add_tokens( + [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + ], + special_tokens=True, + ) + + # Load data + data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) + + # Start trainer + trainer = Trainer( + model=model, tokenizer=tokenizer, args=training_args, **data_module + ) + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + + # Save model + model.config.use_cache = True + trainer.save_state() + if trainer.is_deepspeed_enabled: + trainer.save_model() + else: + trainer_save_model_safe(trainer) + + +if __name__ == "__main__": + train() diff --git a/fastchat/utils.py b/fastchat/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..545e014146bfc5bce1ffcde2869f4292a6aec11d --- /dev/null +++ b/fastchat/utils.py @@ -0,0 +1,484 @@ +""" +Common utilities.
+""" +from asyncio import AbstractEventLoop +from io import BytesIO +import base64 +import json +import logging +import logging.handlers +import os +import platform +import sys +import time +from typing import AsyncGenerator, Generator +import warnings + +import requests + +from fastchat.constants import LOGDIR + + +handler = None +visited_loggers = set() + + +def build_logger(logger_name, logger_filename): + global handler + + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Set the format of root handlers + if not logging.getLogger().handlers: + if sys.version_info[1] >= 9: + # This is for windows + logging.basicConfig(level=logging.INFO, encoding="utf-8") + else: + if platform.system() == "Windows": + warnings.warn( + "If you are running on Windows, " + "we recommend you use Python >= 3.9 for UTF-8 encoding." + ) + logging.basicConfig(level=logging.INFO) + logging.getLogger().handlers[0].setFormatter(formatter) + + # Redirect stdout and stderr to loggers + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + sl = StreamToLogger(stdout_logger, logging.INFO) + sys.stdout = sl + + stderr_logger = logging.getLogger("stderr") + stderr_logger.setLevel(logging.ERROR) + sl = StreamToLogger(stderr_logger, logging.ERROR) + sys.stderr = sl + + # Get logger + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + + # Avoid httpx flooding POST logs + logging.getLogger("httpx").setLevel(logging.WARNING) + + # if LOGDIR is empty, then don't try output log to local file + if LOGDIR != "": + os.makedirs(LOGDIR, exist_ok=True) + filename = os.path.join(LOGDIR, logger_filename) + handler = logging.handlers.TimedRotatingFileHandler( + filename, when="D", utc=True, encoding="utf-8" + ) + handler.setFormatter(formatter) + + for l in [stdout_logger, stderr_logger, logger]: + if l in visited_loggers: + continue + visited_loggers.add(l) + l.addHandler(handler) + + return logger + + +class StreamToLogger(object): + """ + Fake file-like stream object that redirects writes to a logger instance. + """ + + def __init__(self, logger, log_level=logging.INFO): + self.terminal = sys.stdout + self.logger = logger + self.log_level = log_level + self.linebuf = "" + + def __getattr__(self, attr): + return getattr(self.terminal, attr) + + def write(self, buf): + temp_linebuf = self.linebuf + buf + self.linebuf = "" + for line in temp_linebuf.splitlines(True): + # From the io.TextIOWrapper docs: + # On output, if newline is None, any '\n' characters written + # are translated to the system default line separator. + # By default sys.stdout.write() expects '\n' newlines and then + # translates them so this is still cross platform. + if line[-1] == "\n": + encoded_message = line.encode("utf-8", "ignore").decode("utf-8") + self.logger.log(self.log_level, encoded_message.rstrip()) + else: + self.linebuf += line + + def flush(self): + if self.linebuf != "": + encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8") + self.logger.log(self.log_level, encoded_message.rstrip()) + self.linebuf = "" + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. 
+ """ + import torch + + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def get_gpu_memory(max_gpus=None): + """Get available memory for each GPU.""" + import torch + + gpu_memory = [] + num_gpus = ( + torch.cuda.device_count() + if max_gpus is None + else min(max_gpus, torch.cuda.device_count()) + ) + + for gpu_id in range(num_gpus): + with torch.cuda.device(gpu_id): + device = torch.cuda.current_device() + gpu_properties = torch.cuda.get_device_properties(device) + total_memory = gpu_properties.total_memory / (1024**3) + allocated_memory = torch.cuda.memory_allocated() / (1024**3) + available_memory = total_memory - allocated_memory + gpu_memory.append(available_memory) + return gpu_memory + + +def oai_moderation(text, custom_thresholds=None): + """ + Check whether the text violates OpenAI moderation API. + """ + import openai + + client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"]) + + # default to true to be conservative + flagged = True + MAX_RETRY = 3 + for _ in range(MAX_RETRY): + try: + res = client.moderations.create(input=text) + flagged = res.results[0].flagged + if custom_thresholds is not None: + for category, threshold in custom_thresholds.items(): + if getattr(res.results[0].category_scores, category) > threshold: + flagged = True + break + except (openai.OpenAIError, KeyError, IndexError) as e: + print(f"MODERATION ERROR: {e}\nInput: {text}") + return flagged + + +def moderation_filter(text, model_list, do_moderation=False): + # Apply moderation for below models + MODEL_KEYWORDS = [ + "claude", + "gpt", + "bard", + "mistral-large", + "command-r", + "dbrx", + "gemini", + "reka", + ] + + custom_thresholds = {"sexual": 0.3} + # set a stricter threshold for claude + for model in model_list: + if "claude" in model: + custom_thresholds = {"sexual": 0.2} + + for keyword in MODEL_KEYWORDS: + for model in model_list: + if keyword in model: + do_moderation = True + break + + if do_moderation: + return oai_moderation(text, custom_thresholds) + return False + + +def clean_flant5_ckpt(ckpt_path): + """ + Flan-t5 trained with HF+FSDP saves corrupted weights for shared embeddings, + Use this function to make sure it can be correctly loaded. 
+ """ + import torch + + index_file = os.path.join(ckpt_path, "pytorch_model.bin.index.json") + index_json = json.load(open(index_file, "r")) + + weightmap = index_json["weight_map"] + + share_weight_file = weightmap["shared.weight"] + share_weight = torch.load(os.path.join(ckpt_path, share_weight_file))[ + "shared.weight" + ] + + for weight_name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]: + weight_file = weightmap[weight_name] + weight = torch.load(os.path.join(ckpt_path, weight_file)) + weight[weight_name] = share_weight + torch.save(weight, os.path.join(ckpt_path, weight_file)) + + +def pretty_print_semaphore(semaphore): + """Print a semaphore in better format.""" + if semaphore is None: + return "None" + return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" + + +"""A javascript function to get url parameters for the gradio web server.""" +get_window_url_params_js = """ +function() { + const params = new URLSearchParams(window.location.search); + url_params = Object.fromEntries(params); + console.log("url_params", url_params); + return url_params; + } +""" + +get_window_url_params_with_tos_js = """ +function() { + const params = new URLSearchParams(window.location.search); + const url_params = Object.fromEntries(params); + console.log("url_params", url_params); + + const urlContainsLeaderboard = Object.keys(url_params).some(key => key.toLowerCase().includes("leaderboard")); + const msg = "Users of this website are required to agree to the following terms:\\n\\nThe service is a research preview. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.\\nPlease do not upload any private information.\\nThe service collects user dialogue data, including both text and images, and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) or a similar license."; + if (!urlContainsLeaderboard) { + if (window.alerted_before) return; + alert(msg); + window.alerted_before = true; + } + return url_params; + } +""" + +alert_js = """ +() => { + if (window.alerted_before) return; + const msg = "Users of this website are required to agree to the following terms:\\n\\nThe service is a research preview. It only provides limited safety measures and may generate offensive content. 
It must not be used for any illegal, harmful, violent, racist, or sexual purposes.\\nPlease do not upload any private information.\\nThe service collects user dialogue data, including both text and images, and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) or a similar license."; + alert(msg); + window.alerted_before = true; +} +""" + + +def iter_over_async( + async_gen: AsyncGenerator, event_loop: AbstractEventLoop +) -> Generator: + """ + Convert async generator to sync generator + + :param async_gen: the AsyncGenerator to convert + :param event_loop: the event loop to run on + :returns: Sync generator + """ + ait = async_gen.__aiter__() + + async def get_next(): + try: + obj = await ait.__anext__() + return False, obj + except StopAsyncIteration: + return True, None + + while True: + done, obj = event_loop.run_until_complete(get_next()) + if done: + break + yield obj + + +def detect_language(text: str) -> str: + """Detect the language of a string.""" + import polyglot # pip3 install polyglot pyicu pycld2 + from polyglot.detect import Detector + from polyglot.detect.base import logger as polyglot_logger + import pycld2 + + polyglot_logger.setLevel("ERROR") + + try: + lang_code = Detector(text).language.name + except (pycld2.error, polyglot.detect.base.UnknownLanguage): + lang_code = "unknown" + return lang_code + + +def parse_gradio_auth_creds(filename: str): + """Parse a username:password file for gradio authorization.""" + gradio_auth_creds = [] + with open(filename, "r", encoding="utf8") as file: + for line in file.readlines(): + gradio_auth_creds += [x.strip() for x in line.split(",") if x.strip()] + if gradio_auth_creds: + auth = [tuple(cred.split(":")) for cred in gradio_auth_creds] + else: + auth = None + return auth + + +def is_partial_stop(output: str, stop_str: str): + """Check whether the output contains a partial stop str.""" + for i in range(0, min(len(output), len(stop_str))): + if stop_str.startswith(output[-i:]): + return True + return False + + +def run_cmd(cmd: str): + """Run a bash command.""" + print(cmd) + return os.system(cmd) + + +def is_sentence_complete(output: str): + """Check whether the output is a complete sentence.""" + end_symbols = (".", "?", "!", "...", "。", "?", "!", "…", '"', "'", "”") + return output.endswith(end_symbols) + + +# Models don't use the same configuration key for determining the maximum +# sequence length. Store them here so we can sanely check them. +# NOTE: The ordering here is important. Some models have two of these and we +# have a preference for which value gets used.
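# A hedged sketch of how these keys are typically consumed (the model name is
# hypothetical): load a config and let `get_context_length`, defined below, scan
# the keys in order and apply any `rope_scaling` factor it finds.
#
#     config = transformers.AutoConfig.from_pretrained("some-model")
#     context_len = get_context_length(config)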
+SEQUENCE_LENGTH_KEYS = [ + "max_position_embeddings", + "max_sequence_length", + "seq_length", + "max_seq_len", + "model_max_length", +] + + +def get_context_length(config): + """Get the context length of a model from a huggingface model config.""" + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling: + rope_scaling_factor = config.rope_scaling["factor"] + else: + rope_scaling_factor = 1 + + for key in SEQUENCE_LENGTH_KEYS: + val = getattr(config, key, None) + if val is not None: + return int(rope_scaling_factor * val) + return 2048 + + +def str_to_torch_dtype(dtype: str): + import torch + + if dtype is None: + return None + elif dtype == "float32": + return torch.float32 + elif dtype == "float16": + return torch.float16 + elif dtype == "bfloat16": + return torch.bfloat16 + else: + raise ValueError(f"Unrecognized dtype: {dtype}") + + +def load_image(image_file): + from PIL import Image + import requests + + image = None + + if image_file.startswith("http://") or image_file.startswith("https://"): + timeout = int(os.getenv("REQUEST_TIMEOUT", "3")) + response = requests.get(image_file, timeout=timeout) + image = Image.open(BytesIO(response.content)) + elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")): + image = Image.open(image_file) + elif image_file.startswith("data:"): + image_file = image_file.split(",")[1] + image = Image.open(BytesIO(base64.b64decode(image_file))) + else: + image = Image.open(BytesIO(base64.b64decode(image_file))) + + return image + + +def upload_image_file_to_gcs(image, filename): + from google.cloud import storage + import io + + storage_client = storage.Client() + # upload file to GCS + bucket = storage_client.get_bucket("arena_service_data") + + blob = bucket.blob(f"{filename}") + if not blob.exists(): + buffer = io.BytesIO() + image.save(buffer, format="PNG") + buffer.seek(0) + blob.upload_from_file(buffer, content_type="image/png") + + return blob.public_url + + +def get_image_file_from_gcs(filename): + from google.cloud import storage + + storage_client = storage.Client() + bucket = storage_client.get_bucket("arena_service_data") + blob = bucket.blob(f"{filename}") + contents = blob.download_as_bytes() + + return contents + + +def image_moderation_request(image_bytes, endpoint, api_key): + headers = {"Content-Type": "image/jpeg", "Ocp-Apim-Subscription-Key": api_key} + + MAX_RETRIES = 3 + for _ in range(MAX_RETRIES): + response = requests.post(endpoint, headers=headers, data=image_bytes).json() + try: + if response["Status"]["Code"] == 3000: + break + except: + time.sleep(0.5) + return response + + +def image_moderation_provider(image, api_type): + if api_type == "nsfw": + endpoint = os.environ["AZURE_IMG_MODERATION_ENDPOINT"] + api_key = os.environ["AZURE_IMG_MODERATION_API_KEY"] + response = image_moderation_request(image, endpoint, api_key) + print(response) + return response["IsImageAdultClassified"] + elif api_type == "csam": + endpoint = ( + "https://api.microsoftmoderator.com/photodna/v1.0/Match?enhance=false" + ) + api_key = os.environ["PHOTODNA_API_KEY"] + response = image_moderation_request(image, endpoint, api_key) + return response["IsMatch"] + + +def image_moderation_filter(image): + print(f"moderating image") + + image_bytes = base64.b64decode(image.base64_str) + + nsfw_flagged = image_moderation_provider(image_bytes, "nsfw") + csam_flagged = False + + if nsfw_flagged: + csam_flagged = image_moderation_provider(image_bytes, "csam") + + return nsfw_flagged, csam_flagged diff --git a/format.sh b/format.sh 
new file mode 100644 index 0000000000000000000000000000000000000000..798b698641772e7ec997a7f026e660a1c6282a13 --- /dev/null +++ b/format.sh @@ -0,0 +1,71 @@ +#!/usr/bin/env bash + +# Adapted from https://github.com/skypilot-org/skypilot/blob/master/format.sh + +# Cause the script to exit if a single command fails +set -eo pipefail + +# this stops git rev-parse from failing if we run this from the .git directory +builtin cd "$(dirname "${BASH_SOURCE:-$0}")" +ROOT="$(git rev-parse --show-toplevel)" +builtin cd "$ROOT" || exit 1 + +BLACK_VERSION=$(black --version | head -n 1 | awk '{print $2}') +PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}') + +# # params: tool name, tool version, required version +tool_version_check() { + if [[ $2 != $3 ]]; then + echo "Wrong $1 version installed: $3 is required, not $2." + exit 1 + fi +} + +tool_version_check "black" $BLACK_VERSION "23.3.0" +tool_version_check "pylint" $PYLINT_VERSION "2.8.2" + +# Format files that differ from main branch. Ignores dirs that are not slated +# for autoformat yet. +format_changed() { + # The `if` guard ensures that the list of filenames is not empty, which + # could cause yapf to receive 0 positional arguments, making it hang + # waiting for STDIN. + # + # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that + # exist on both branches. + MERGEBASE="$(git merge-base origin/main HEAD)" + + if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then + git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 black + fi +} + +## This flag formats individual files. --files *must* be the first command line +## arg to use this option. +if [[ "$1" == '--files' ]]; then + black "${@:2}" + # If `--all` is passed, then any further arguments are ignored and the + # entire python directory is formatted. +elif [[ "$1" == '--all' ]]; then + # Format all files + black fastchat +else + # Format only the files that changed in last commit. + format_changed +fi +echo 'FastChat Black: Done' + +# Run Pylint +echo 'FastChat Pylint:' +pylint fastchat +# TODO(suquark): disable 'pylint_quotes' for now due to too many inconsistent quotes +# pylint --load-plugins pylint_quotes fastchat + +if ! git diff --quiet &>/dev/null; then + echo 'Reformatted files. Please review and stage the changes.' 
+ echo 'Changes not staged for commit:' + echo + git --no-pager diff --name-only + + exit 1 +fi diff --git a/playground/FastChat_API_GoogleColab.ipynb b/playground/FastChat_API_GoogleColab.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..f9fab15020691b489b0acbaf91b8fc5608945b36 --- /dev/null +++ b/playground/FastChat_API_GoogleColab.ipynb @@ -0,0 +1,347 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "1UDur96B5C7T" + }, + "source": [ + "# FastChat API using Google Colab\n", + "\n", + "[ggcr](https://github.com/ggcr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NQWpzwse8PrC" + }, + "outputs": [], + "source": [ + "%cd /content/\n", + "\n", + "# clone FastChat\n", + "!git clone https://github.com/lm-sys/FastChat.git\n", + "\n", + "# install dependencies\n", + "%cd FastChat\n", + "!python3 -m pip install -e \".[model_worker,webui]\" --quiet" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "97181RzwSjha" + }, + "source": [ + "See [openai_api.md](https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md) from FastChat docs.\n", + "\n", + "Because in Google Colab we are limited in resources and executing things in the background is not stable, we will run each API process in a thread and communicate them via explicit addresses:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "BrhPP9ZggVL0", + "outputId": "be510360-21ba-4f6f-d6b6-24c710bdff4d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/content\n" + ] + } + ], + "source": [ + "import subprocess\n", + "import threading\n", + "\n", + "%cd /content/\n", + "\n", + "# Using 127.0.0.1 because localhost does not work properly in Colab\n", + "\n", + "def run_controller():\n", + " subprocess.run([\"python3\", \"-m\", \"fastchat.serve.controller\", \"--host\", \"127.0.0.1\"])\n", + "\n", + "def run_model_worker():\n", + " subprocess.run([\"python3\", \"-m\", \"fastchat.serve.model_worker\", \"--host\", \"127.0.0.1\", \"--controller-address\", \"http://127.0.0.1:21001\", \"--model-path\", \"lmsys/vicuna-7b-v1.5\", \"--load-8bit\"])\n", + "\n", + "def run_api_server():\n", + " subprocess.run([\"python3\", \"-m\", \"fastchat.serve.openai_api_server\", \"--host\", \"127.0.0.1\", \"--controller-address\", \"http://127.0.0.1:21001\", \"--port\", \"8000\"])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "3S8vDHy3gWUv" + }, + "outputs": [], + "source": [ + "# Start controller thread\n", + "# see `controller.log` on the local storage provided by Colab\n", + "controller_thread = threading.Thread(target=run_controller)\n", + "controller_thread.start()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "UAU097ymgbNf" + }, + "outputs": [], + "source": [ + "# Start model worker thread\n", + "\n", + "# see `controller.log` on the local storage provided by Colab\n", + "# important to wait until the checkpoint shards are fully downloaded\n", + "model_worker_thread = threading.Thread(target=run_model_worker)\n", + "model_worker_thread.start()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "id": "bTqHMMr1gcQJ" + }, + "outputs": [], + "source": [ + "# Start API server thread\n", + "api_server_thread = threading.Thread(target=run_api_server)\n", + "api_server_thread.start()" + ] + }, + { + "cell_type": "markdown", + 
"metadata": { + "id": "iBdjt9I6fuSn" + }, + "source": [ + "We now have the API running at http://127.0.0.1:8000/v1/ locally from Google Colab.\n", + "\n", + "We can run the examples from FastChat with curl." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KtaxADXqhazs" + }, + "source": [ + "Try chat completion with" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "MZGd4y2SfBTT", + "outputId": "066835bb-f7f0-4e16-f54a-2f74b0e2f9d9" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\"id\":\"chatcmpl-3RViU5mrsEBNu8oSxexAEb\",\"object\":\"chat.completion\",\"created\":1705781842,\"model\":\"vicuna-7b-v1.5\",\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\",\"content\":\"Sure thing! Here's one for you:\\n\\nWhy did the tomato turn red?\\n\\nBecause it saw the salad dressing!\"},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":50,\"total_tokens\":82,\"completion_tokens\":32}}" + ] + } + ], + "source": [ + "!curl http://127.0.0.1:8000/v1/chat/completions \\\n", + " -H \"Content-Type: application/json\" \\\n", + " -d '{ \\\n", + " \"model\": \"vicuna-7b-v1.5\", \\\n", + " \"messages\": [{\"role\": \"user\", \"content\": \"Hello, can you tell me a joke for me?\"}], \\\n", + " \"temperature\": 0.5 \\\n", + " }'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "umgVIilThc6a" + }, + "source": [ + "Try embeddings with" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "VraqDkMahfAQ", + "outputId": "18710c2c-1994-4f36-eff1-6aff5a2a83a4" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + 
"{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"embedding\":[0.0229715034365654,-0.020740192383527756,0.01663232035934925,0.013713006861507893,-0.01602417416870594,-0.006382038351148367,0.011642662808299065,-0.021167458966374397,0.004879815969616175,-0.005442662630230188,0.0034834356047213078,-0.010336925275623798,-0.009551243856549263,0.0005828586872667074,-0.0089940270408988,-0.0018360239919275045,-0.021827373653650284,0.007349758874624968,-0.0011765437666326761,-0.01432803925126791,0.012239773757755756,-0.018455859273672104,0.016475312411785126,-0.006144467741250992,-0.013893244788050652,-0.00961716752499342,0.00827623251825571,0.0013034207513555884,0.006355977617204189,0.007773293182253838,0.0029199880082160234,-0.014487813226878643,-0.01615595631301403,0.007242684718221426,-0.004686516709625721,-0.0034376305993646383,-0.0046915397979319096,0.0007899928605183959,-0.003679676679894328,-0.022176748141646385,-0.005467468872666359,-0.02236158587038517,0.02086811512708664,0.0029669292271137238,-0.0168694406747818,0.025603512302041054,0.009139388799667358,0.02165624313056469,-0.004472456872463226,0.0006205983809195459,0.0011453271145001054,0.014379195868968964,0.01994524523615837,-0.017613859847187996,0.005462903995066881,0.005702079739421606,-0.021057194098830223,-0.021468186751008034,-0.004666909575462341,-0.007595115341246128,-0.009129735641181469,-0.0161031112074852,0.009293882176280022,0.00953285675495863,-0.0013638428645208478,0.0007091081934049726,0.0018222536891698837,0.020376019179821014,0.01186810340732336,-0.013734177686274052,-0.004418510012328625,-0.006746952421963215,-0.0006970430840738118,-0.006644704379141331,-0.04453064501285553,0.003871878841891885,-0.01059865765273571,-0.024984514340758324,0.011757172644138336,-0.016218630596995354,-0.009141125716269016,-0.004623874556273222,-0.009732221253216267,-0.009169373661279678,-0.006947007961571217,-0.005838882178068161,-0.0068959807977080345,-0.000743469747249037,0.008742589503526688,-0.008120769634842873,-0.018119709566235542,-0.004530956968665123,-0.003916825633496046,0.02495340257883072,0.010598400607705116,0.010666633024811745,0.00679260678589344,-0.009019959717988968,-0.004487940575927496,-0.0026543298736214638,0.00286748050712049,0.012851846404373646,0.0012102456530556083,0.014895712956786156,-0.01030716486275196,0.01633496955037117,0.015731101855635643,-0.009079995565116405,0.016830960288643837,0.00940327625721693,-0.0014347939286381006,0.0207867082208395,0.06265891343355179,0.002649270463734865,-0.007526970934122801,0.004714089445769787,0.006397288292646408,-0.0029612022917717695,-0.0015034123789519072,-0.006392269395291805,-0.012309122830629349,0.0040127672255039215,0.001810954650864005,-0.016414696350693703,-0.019156336784362793,0.0003308420709799975,0.007823580875992775,0.0020239183213561773,-0.0024881847202777863,-0.008919963613152504,-0.01775810308754444,-0.012687149457633495,0.0022407048381865025,-0.009261680766940117,0.006048525683581829,0.00518012186512351,0.0029072873294353485,-7.72168641560711e-06,0.012007351964712143,-0.0004918070626445115,0.0013227892341092229,0.006292788311839104,-0.010167273692786694,-0.009050589054822922,0.008057740516960621,0.006250383332371712,0.014853340573608875,0.02723078615963459,-0.02242557890713215,0.04399850592017174,0.00313431303948164,-0.022166002541780472,0.010024639777839184,0.003234871895983815,0.0030383227858692408,0.012888548895716667,0.01507903728634119,0.00479199830442667,-0.0024831658229231834,0.008515636436641216,0.0005489314789883792,0.004214818123728037,0.006
590660661458969,-0.012804229743778706,0.011747709475457668,0.002035082783550024,0.0143223125487566,0.0134012121707201,-0.0008568498305976391,0.0025005715433508158,-0.012422841973602772,0.014866000972688198,0.020239505916833878,-0.0034607010893523693,-0.026886560022830963,-0.0023535056971013546,-0.0037942437920719385,0.013139543123543262,0.004902820568531752,0.008357052691280842,-0.011724174953997135,0.005840683821588755,0.009768190793693066,0.00013014259457122535,0.016845345497131348,-0.006546108052134514,-0.00838533416390419,-0.01408461295068264,-0.0022769987117499113,0.010644538328051567,0.002947496483102441,0.02589692734181881,0.012639564462006092,0.004540625493973494,-0.0176566019654274,-0.010239857248961926,0.01839127205312252,0.0031600680667907,0.011127336882054806,0.0036535318940877914,0.015353705734014511,-0.026527339592576027,-0.008746611885726452,0.01886408030986786,0.00887488853186369,-0.0001859961193986237,0.001222877879627049,0.0065072583965957165,-0.009838716126978397,0.008630175143480301,-0.00633110711351037,0.02635054476559162,-0.005968477576971054,-0.013434287160634995,0.01017901673913002,-0.003692896803840995,-0.005410553887486458,-0.006332104559987783,-0.017778540030121803,-0.017085647210478783,-0.005269246641546488,-0.013628004118800163,-0.0005570553475990891,0.010984581895172596,0.000956009142100811,0.009669160470366478,-0.0019082700600847602,-0.05074448138475418,-0.03876679390668869,0.0011635294649749994,-0.012585809454321861,0.008794615045189857,0.00023998660617507994,-0.00455761281773448,-0.0020947649609297514,0.017387693747878075,0.004844747018069029,0.008267332799732685,0.00747610442340374,0.02141532674431801,-0.02262278087437153,-0.014600872062146664,-0.021727152168750763,0.008812149986624718,0.009474638849496841,0.03191479295492172,-0.019652077928185463,0.01944698765873909,0.017112286761403084,0.015296016819775105,0.014461753889918327,-0.019157931208610535,0.009540014900267124,0.004215397406369448,-0.008012793958187103,0.013523118570446968,-0.009407458826899529,-0.029304828494787216,0.012041181325912476,0.015149015933275223,0.0031983305234462023,-0.0003109185490757227,0.03257888928055763,0.007614033296704292,-0.005175750236958265,-0.002383652376011014,0.006435382179915905,0.006068408954888582,-0.007524268701672554,0.02373131737112999,0.004817254841327667,0.005436067469418049,-0.0059105646796524525,-0.005925316829234362,-6.454042886616662e-05,-0.008412199094891548,-0.00655836658552289,-0.0010680218692868948,-0.004262322559952736,0.0015925978077575564,0.00412611523643136,-0.011034490540623665,0.009839101694524288,0.00415002042427659,-0.007727092131972313,-0.010377302765846252,0.0007711391081102192,-0.009322070516645908,0.0035655524116009474,-0.026301125064492226,-0.006197007372975349,0.0006739745149388909,-0.00818476639688015,-0.02090131863951683,-0.002644758205860853,0.006994722411036491,-0.0016304099699482322,0.01705804094672203,-0.016460495069622993,0.017486274242401123,0.013187418691813946,0.0033816162031143904,0.017844069749116898,-0.017695210874080658,-0.011941025033593178,0.009029353968799114,0.0033719318453222513,-0.009064359590411186,0.012252643704414368,0.0011845449917018414,0.003185839159414172,0.003374891821295023,-0.007335654925554991,0.0029391313437372446,0.000280876352917403,0.0048222895711660385,-0.0003767217858694494,-0.045474909245967865,0.004725527483969927,0.0075803473591804504,0.005909985862672329,0.002949362387880683,-0.0036183823831379414,0.0026071954052895308,-0.005563989747315645,-0.012707033194601536,-0.004933884367346764,-0.01665957830846
3097,-0.0081319659948349,0.012579865753650665,-0.022291865199804306,-0.018159057945013046,-0.0069056968204677105,-0.00018650286074262112,-0.006835494190454483,0.0006484286277554929,0.005561383906751871,0.0062789213843643665,0.029090696945786476,0.002546998206526041,0.009344656951725483,-0.0038842656649649143,-0.012519339099526405,-0.0025535617023706436,-0.003679415676742792,-0.0033875037916004658,0.003728062380105257,-0.014787501655519009,0.0023771373089402914,0.005443841218948364,-0.00957341119647026,-0.015306569635868073,0.0046866778284311295,-0.016635537147521973,-0.01424899697303772,0.001698320615105331,-0.004534294828772545,0.0066452836617827415,0.010703673586249352,0.004293128848075867,-0.009486992843449116,-0.0031507215462625027,0.01611129753291607,-0.015744132921099663,-0.014641146175563335,0.0026989546604454517,0.01565713621675968,-0.005524931009858847,0.006648661568760872,0.0040243822149932384,-0.00559786893427372,-0.014391486532986164,0.026553215458989143,-0.009266120381653309,0.020683180540800095,0.00994131714105606,0.0026739235036075115,0.0038542025722563267,-0.012158502824604511,-0.010751161724328995,-0.00017412402667105198,-0.017064156010746956,-0.010691382922232151,0.00937278475612402,-0.014700417406857014,-0.005352479871362448,0.012342552654445171,0.009191831573843956,-0.011637836694717407,-0.012737436220049858,0.01105053722858429,0.020749129354953766,0.07297933101654053,0.027850160375237465,-0.005428216885775328,-0.019425511360168457,0.0016134463949128985,-0.007674881722778082,0.004896160680800676,-0.006309020332992077,0.0028925116639584303,-0.016418879851698875,-0.012568380683660507,-0.0076565672643482685,-0.002051394898444414,0.011267355643212795,0.01101701334118843,0.02482358179986477,0.011389358900487423,-0.01589033007621765,0.0005615596892312169,-0.027247965335845947,-0.008588980883359909,0.005675439722836018,0.008922569453716278,-0.003106530988588929,0.00925450585782528,-0.00030810333555564284,-0.002115500858053565,-0.007074093911796808,-0.005927231162786484,-0.017885340377688408,-0.016033342108130455,-0.0049004401080310345,0.006337509956210852,0.01978384517133236,0.001572070992551744,-0.0143946073949337,-0.008655560202896595,-0.0011587677290663123,-2.521412170608528e-05,-0.01082194410264492,0.010964666493237019,-0.011412781663239002,0.008038532920181751,0.006299568805843592,-0.008974144235253334,0.006545931100845337,0.0006125871441327035,0.00486041558906436,0.0042688059620559216,0.0018871801439672709,-0.006763682700693607,0.013578971847891808,-0.0020262349862605333,-0.0024552710819989443,-0.01506423857063055,0.0054992204532027245,0.011333892121911049,-0.007717472035437822,-0.005762179847806692,0.0007979075890034437,0.007761630229651928,-0.00952511839568615,-0.010288495570421219,0.014522014185786247,-0.005318223498761654,0.009297103621065617,0.0038411528803408146,0.012293890118598938,0.004698003176599741,-0.007060967851430178,-0.004558722488582134,-0.003963573835790157,0.016085496172308922,0.015816137194633484,0.0027972774114459753,-0.017336538061499596,0.014937583357095718,0.013450084254145622,0.06357342004776001,-0.009506811387836933,0.007877970114350319,0.007048371247947216,0.011457744054496288,0.023370005190372467,0.014203527010977268,-0.004614254459738731,-0.008159955963492393,0.0030794248450547457,-0.0010602197144180536,0.0006093939300626516,-0.010418003425002098,-0.007668149657547474,0.015968769788742065,-0.0015574641292914748,-0.018846578896045685,-0.003667157609015703,0.0019307112088426948,-0.001895931432954967,-0.010295855812728405,0.00023113582574296743,0.
007993489503860474,0.0022910244297236204,0.00033837073715403676,-0.005313453264534473,0.0010675875237211585,-0.01549510844051838,0.007410695310682058,0.009646059945225716,-0.012997191399335861,0.010529725812375546,-0.019208982586860657,-0.010029473342001438,-0.013124711811542511,0.029043130576610565,-0.00493550905957818,0.008303387090563774,0.0067044831812381744,0.005133184138685465,-0.008268092758953571,0.0027517518028616905,-0.013479426503181458,-0.01547516044229269,-0.020013773813843727,-0.006451855413615704,0.008133156225085258,-0.006830539554357529,-0.007085484452545643,0.010033013299107552,0.002104497514665127,0.0005678657325915992,0.006996427197009325,-0.00785919837653637,-0.029753299430012703,0.03372034803032875,-0.008247010409832,0.008989491499960423,0.017457574605941772,-0.0059603373520076275,-0.003432418452575803,-0.014526166021823883,0.01819109544157982,-0.007616993971168995,-0.008361894637346268,0.008198246359825134,0.004229682497680187,-0.02080651931464672,0.009076694026589394,-0.006605580914765596,0.0037523536011576653,-0.010452975519001484,-0.012760377489030361,-0.017025675624608994,-0.007993683218955994,0.013692287728190422,0.017206765711307526,0.006106856279075146,0.011746293865144253,-0.009011680260300636,-0.007511272560805082,0.006244495511054993,0.009395747445523739,0.006921007763594389,0.00926200207322836,0.03370635211467743,0.0026780739426612854,0.012087206356227398,0.0012626887764781713,-0.014491417445242405,-0.007984738796949387,-0.02033303491771221,-0.008010058663785458,-0.0027411666233092546,-0.006356299389153719,0.014341359958052635,0.00924749206751585,0.008061794564127922,-0.014423820190131664,-0.0027274927124381065,-0.009464149363338947,0.0032869288697838783,0.028920968994498253,-0.007417581044137478,-0.012927711941301823,-0.006823397241532803,0.0021555088460445404,-0.008643687702715397,-0.0023652170784771442,-0.0060961428098380566,-0.017238536849617958,-0.007533663418143988,0.0022437411826103926,-0.0029654495883733034,0.007918364368379116,-0.004272923804819584,0.022094689309597015,-0.01293826475739479,-0.03929437696933746,-0.05735565349459648,-0.013071688823401928,0.0007404614589177072,-0.000720368989277631,-0.006117763463407755,-0.011282929219305515,0.010729444213211536,-0.014913954772055149,0.00311655318364501,0.006948134861886501,-0.00748022273182869,-0.02309916727244854,-0.0178229883313179,-0.0072602517902851105,0.007839913479983807,0.012868576683104038,0.002075975527986884,0.0007498079212382436,0.005815781187266111,-0.011992518790066242,0.010061261244118214,0.004755143541842699,-0.0014543153811246157,0.014422083273530006,-0.0023919050581753254,0.009424189105629921,-0.01841503195464611,0.008597759529948235,0.023288220167160034,-0.009507520124316216,0.015740947797894478,-0.0004225693119224161,0.02476677857339382,-0.011370633728802204,0.011675688438117504,0.020527847111225128,-0.0073279449716210365,-0.013483609072864056,-0.019474929198622704,-0.004016772843897343,-0.012925073504447937,-0.00565439835190773,0.0104595385491848,-0.007314899004995823,0.010194428265094757,0.0022050561383366585,0.011519340798258781,-0.0059105330146849155,-0.0007297637057490647,-0.016200484707951546,0.015271657146513462,-0.016203250735998154,0.034517351537942886,0.0006107089575380087,-0.013269267976284027,0.01328535471111536,-0.02016814425587654,-0.007773164194077253,-0.007333156652748585,-0.01815428026020527,0.006929537747055292,-0.0034732790663838387,-0.004669690039008856,0.0016878641908988357,-0.03094855323433876,0.0019403311889618635,-0.005923015996813774,-0.0040122526697814465,
0.009299001656472683,-0.006708343978971243,0.01585310511291027,0.0010694535449147224,0.0006908577051945031,-0.0015497022541239858,-0.014749257825314999,0.013069666922092438,-0.0003381777205504477,-0.0186776015907526,-0.00869465060532093,-0.005246113985776901,0.004712183494120836,-0.0033125269692391157,0.005922533571720123,0.005009307526051998,-0.002772809471935034,0.0018297180067747831,-0.007289668545126915,-0.025313491001725197,-0.010890730656683445,-0.013207301497459412,-0.015217771753668785,-0.0064299451187253,0.0012019408168271184,0.013148745521903038,-0.022279510274529457,0.008878774009644985,-0.007133841048926115,-0.0007347667124122381,0.007130189798772335,0.0017936835065484047,0.012268022634088993,0.007812416646629572,0.009994118474423885,-0.01274168398231268,-0.000458410766441375,-0.006630516145378351,0.0004267197218723595,0.013977475464344025,-0.003951766062527895,-0.0167144313454628,-0.012754247523844242,0.012914633378386497,0.010781855322420597,0.002908888040110469,-0.007131427992135286,0.017916306853294373,-0.005879903212189674,-0.002502115909010172,-0.0016746085602790117,-0.024386180564761162,-0.008716223761439323,0.003937223460525274,0.004685036838054657,-0.005052074324339628,-0.004745748359709978,-0.004316418897360563,-0.009056701324880123,-0.011055074632167816,0.0087593924254179,-0.016003968194127083,-0.001959120621904731,0.014024545438587666,-0.005205253139138222,-0.0034684527199715376,-0.00704217841848731,0.004913646727800369,0.01903299242258072,-0.007594246882945299,-0.0001278904383070767,-0.00024535658303648233,0.01912636123597622,0.02121288888156414,0.01097018364816904,-0.005211591720581055,-0.004693691153079271,0.0002123745362041518,0.01864037662744522,0.004567956551909447,-0.006998493801802397,0.002807476557791233,-0.0272210780531168,0.008950882591307163,-0.007628897670656443,0.017757385969161987,0.011070613749325275,-0.02169198729097843,0.005343310534954071,0.0013322805752977729,-0.004593148827552795,0.009079867042601109,0.011012605391442776,0.00658367108553648,-0.004797258879989386,-0.006833371240645647,-0.0069283475168049335,-0.009916930459439754,-0.006784595549106598,-0.03476946800947189,0.020896492525935173,0.008564138785004616,-0.0012716330820694566,-0.013008822686970234,-0.000613439769949764,0.0047750589437782764,-0.012346075847744942,0.006973704788833857,-0.013979197479784489,-0.006083691958338022,0.005035505164414644,0.011145804077386856,0.013424682430922985,-0.00019037174934055656,-0.008892635814845562,-0.01950671710073948,-0.010386078618466854,0.01175111997872591,-0.014368708245456219,0.00041413979488424957,-0.014867283403873444,0.0020979661494493484,-0.002743129152804613,0.004309915471822023,-0.012755325064063072,0.013642803765833378,0.008863402530550957,0.0013711462961509824,-0.019572222605347633,0.0036479418631643057,0.1259939968585968,0.01384377758949995,0.015267448499798775,0.014036224223673344,0.0038570465985685587,-0.005283885635435581,0.010237026028335094,-0.011374881491065025,-0.011878897435963154,-0.008971023373305798,-0.009165126830339432,-0.0010226268786936998,-0.007337307557463646,-0.010756309144198895,-0.014150279574096203,0.002133630681782961,-0.015334047377109528,0.00481215538457036,-0.013047880493104458,-0.014511879533529282,-0.0030851999763399363,-0.007749861106276512,-0.006487664300948381,0.013752967119216919,-0.012187069281935692,0.0007167012081481516,-0.0016341822920367122,-0.004467220976948738,0.0042928713373839855,0.022611349821090698,0.0005482397391460836,-0.017850179225206375,-0.014368931762874126,-0.02213916927576065,0.009322037920
355797,-0.008927379734814167,0.0012655361788347363,0.003878731979057193,-0.011226431466639042,0.014120969921350479,-0.013007482513785362,-0.027299636974930763,-0.02149207703769207,0.0018350587924942374,0.0014142269501462579,-0.000801382411736995,0.010266175493597984,0.006652788259088993,0.0005369306891225278,-0.006750426720827818,0.0077108764089643955,0.008079683408141136,-0.0018402388086542487,-0.016589464619755745,-0.009489567019045353,-0.006460928358137608,-0.008930034004151821,0.005508729722350836,-0.021854624152183533,0.0021647908724844456,-4.1697108827065676e-05,0.0023772178683429956,-0.015694361180067062,-0.0025681040715426207,0.02343827858567238,-0.007234286982566118,0.011763988994061947,0.006332748103886843,0.01569999009370804,0.0011719107860699296,-0.0026809938717633486,-0.019673682749271393,0.010832150466740131,0.0020819918718189,0.0021434274967759848,0.014149283058941364,-0.018654564395546913,-0.005904508288949728,0.024274280294775963,0.0020302003249526024,0.009324193932116032,-0.0019528145203366876,0.010275795124471188,-0.007945165038108826,0.02523057907819748,-0.0015196279855445027,-0.0033202609047293663,-0.00838176254183054,0.009073046036064625,0.004423896782100201,0.0025238976813852787,0.0009007186163216829,0.012340654619038105,0.013026034459471703,0.0006704675615765154,-0.011622972786426544,0.0025514704175293446,0.0018054911633953452,-0.00021421245764940977,0.0015564989298582077,0.0002535287057980895,-0.007833908312022686,-0.002614386146888137,0.010472987778484821,0.008430087007582188,-0.010319744236767292,-0.007126948330551386,-0.0032228068448603153,-0.005715849809348583,-0.007379905320703983,0.0007485531968995929,-0.020927315577864647,0.0019611797761172056,0.0038484123069792986,-0.006966795306652784,-0.018788157030940056,0.007531090173870325,-0.006524322554469109,0.010099516250193119,-0.004077629651874304,-0.017544057220220566,-0.0056204223074018955,0.0014705952489748597,0.02655109204351902,-0.004098542500287294,0.00679929880425334,-0.009616298601031303,-0.00428798096254468,-0.004214432090520859,0.017463093623518944,0.007254500407725573,0.011614413931965828,-0.015450838021934032,0.01430854294449091,0.011353002861142159,0.0038417966570705175,0.013071335852146149,-0.003091377206146717,-0.0012477281270548701,-0.012130544520914555,-0.0005112078506499529,0.0007805016357451677,0.01115238294005394,-0.011903454549610615,0.01652473211288452,-0.016062499955296516,0.0243363119661808,0.00521033676341176,-0.019244149327278137,0.015055154450237751,-0.0014579187845811248,0.024649038910865784,0.003033657558262348,-0.004459853284060955,-0.0024275374598801136,-0.004720765631645918,-0.008315999060869217,0.01299308892339468,0.003514010924845934,0.00035230195499025285,-0.0016822096658870578,-0.011835559271275997,0.013584377244114876,0.014042497612535954,-0.0021746200509369373,-0.013556176796555519,0.009201740846037865,-0.016880186274647713,0.006788729690015316,0.007318035699427128,0.0079000573605299,-0.0021131120156496763,0.005459972191601992,-0.01956108957529068,-0.003485738066956401,-0.012780935503542423,-0.010953888297080994,-0.0035778111778199673,0.013985262252390385,0.004123058635741472,-0.017365043982863426,0.02569989673793316,-0.0032679142896085978,-0.006953733041882515,-0.020901406183838844,0.003745210822671652,0.004216748755425215,0.007281791884452105,0.01097949780523777,-0.008859830908477306,0.0076435767114162445,-0.002383668441325426,0.003228791058063507,0.000471006816951558,0.021136121824383736,0.006612015888094902,-0.00790025107562542,0.002388188848271966,-0.01046378631144762,0.0
019024949287995696,-0.020805569365620613,0.008167678490281105,0.01708216592669487,0.003778784302994609,-0.007486400194466114,0.009304165840148926,0.01634320802986622,-0.015319439582526684,0.012349807657301426,0.008008498698472977,0.004085544031113386,-0.0019550668075680733,-0.0013337925774976611,0.005621806252747774,0.00999923050403595,0.0067540789023041725,0.024973737075924873,-0.013562659732997417,-0.009736709296703339,-0.012089909054338932,-0.016808679327368736,0.008086872287094593,0.008295665495097637,-0.012549092061817646,-0.010748330503702164,3.521411053952761e-05,0.0017467420548200607,0.01626216247677803,0.009219243191182613,-0.006609965115785599,0.010143030434846878,-0.020303402096033096,-0.01044105552136898,-0.013318654149770737,0.00010932621080428362,0.007084518671035767,0.007645950186997652,-0.0032920767553150654,-0.01955648884177208,0.0074850814417004585,0.00894773006439209,0.009001234546303749,0.005829519592225552,-0.0045957546681165695,0.0013910618145018816,-0.012523948214948177,0.013304369524121284,0.01453658938407898,0.017666004598140717,-0.004940214566886425,-0.011730528436601162,-0.015571167692542076,-0.010929387994110584,-0.0006716740899719298,0.02221648395061493,0.021565254777669907,0.01226515881717205,-0.0053292508237063885,0.0007020622142590582,0.0024210221599787474,0.01962619461119175,-0.004420963115990162,-0.015309896320104599,0.0034791347570717335,0.02059043198823929,-0.008116353303194046,-0.0032520205713808537,-0.012169377878308296,0.025940747931599617,-9.851584763964638e-05,0.0036511996295303106,0.0037823636084795,-0.010169846937060356,0.010504196397960186,0.013252376578748226,-0.007866725325584412,-0.0026977320667356253,-0.011583752930164337,-0.006372353993356228,-0.0007445314549840987,-0.0030074622482061386,0.016342146322131157,-0.009066401980817318,0.0021215977612882853,0.008862188085913658,0.015515057370066643,0.009001555852591991,-0.024249698966741562,0.020413951948285103,0.008854007348418236,0.0006535120774060488,0.013391399756073952,-0.01817990653216839,-0.0016513630980625749,-0.011816883459687233,0.007374065928161144,0.02026175521314144,-0.019211476668715477,0.00015504502516705543,-0.007945390418171883,0.001324703567661345,0.025466380640864372,0.006762733682990074,-0.01408602949231863,-0.01516133826225996,-0.0069986796006560326,-0.0004754628462251276,-0.01119284238666296,-0.004222266376018524,-0.014954396523535252,0.0031823322642594576,-0.009523541666567326,-0.011928976513445377,-0.0011272374540567398,-0.009063232690095901,-0.011843233369290829,-0.0030050550121814013,-0.010779651813209057,0.017810650169849396,0.009822757914662361,-0.0130256162956357,-0.002755612600594759,0.010061550885438919,-0.002134740585461259,-0.0004929009592160583,-0.011506262235343456,0.004393350332975388,0.002644677646458149,0.013704448938369751,-0.015646131709218025,-0.005174269899725914,0.017940374091267586,0.006815964821726084,-0.014483116567134857,-0.018775692209601402,-0.017056433483958244,-0.00333380582742393,-0.01628420129418373,-0.02220962941646576,-0.007394126150757074,0.004732364322990179,0.003667865414172411,0.013815898448228836,-0.014784134924411774,0.006790837273001671,-0.005050111562013626,-0.01184664387255907,-0.005963458679616451,0.01068057306110859,0.01837034337222576,6.692128226859495e-05,-0.0020520382095128298,-0.005477442871779203,0.008534909226000309,0.021816853433847427,0.019038107246160507,0.008523069322109222,-0.021777216345071793,-0.01595551334321499,-0.012562203221023083,0.012347427196800709,0.013057525269687176,-0.015681490302085876,0.012324455194175243,-0
.0041071330197155476, ... (remaining embedding values truncated) ...,-0.010492903180420399],\"index\":0}],\"model\":\"vicuna-7b-v1.5\",\"usage\":{\"prompt_tokens\":13,\"total_tokens\":13}}" + ] + } + ], + "source": [ + "!curl http://127.0.0.1:8000/v1/embeddings \\\n", + " -H \"Content-Type: application/json\" \\\n", + " -d '{ \\\n", + " \"model\": \"vicuna-7b-v1.5\", \\\n", + " \"input\": \"Hello, can you tell me a joke for me?\" \\\n", + " }'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-U2SZWTghxzc" + }, + "source": [ + "Try text completion with" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "85T5NO7Wh03R", +
"outputId": "1a2c9568-2aa3-4a89-ecd8-8af496be1a41" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\"id\":\"cmpl-kB3gg4KtgcGdif9V4eNbh6\",\"object\":\"text_completion\",\"created\":1705782008,\"model\":\"vicuna-7b-v1.5\",\"choices\":[{\"index\":0,\"text\":\", there was a little girl named Alice. Alice lived in a small village nestled in a valley\",\"logprobs\":null,\"finish_reason\":\"length\"}],\"usage\":{\"prompt_tokens\":5,\"total_tokens\":24,\"completion_tokens\":19}}" + ] + } + ], + "source": [ + "!curl http://127.0.0.1:8000/v1/completions \\\n", + " -H \"Content-Type: application/json\" \\\n", + " -d '{ \\\n", + " \"model\": \"vicuna-7b-v1.5\", \\\n", + " \"prompt\": \"Once upon a time\", \\\n", + " \"max_tokens\": 20, \\\n", + " \"temperature\": 0.5 \\\n", + " }'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EDxLbQDKVLiQ" + }, + "source": [ + "Try create_embeddings to analyze the prompts!" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "bbrFoxgaplhK", + "outputId": "48e23158-1468-445d-a4cd-b5bd67bd3bde" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cosine similarity between 'The quick brown fox' and 'The quick brown dog': 0.90\n", + "Cosine similarity between 'The quick brown fox' and 'The fast brown fox': 0.86\n", + "Cosine similarity between 'The quick brown fox' and 'A completely different sentence': 0.58\n", + "Cosine similarity between 'The quick brown dog' and 'The fast brown fox': 0.84\n", + "Cosine similarity between 'The quick brown dog' and 'A completely different sentence': 0.66\n", + "Cosine similarity between 'The fast brown fox' and 'A completely different sentence': 0.62\n" + ] + } + ], + "source": [ + "import json\n", + "import numpy as np\n", + "import requests\n", + "from scipy.spatial.distance import cosine\n", + "\n", + "\n", + "def get_embedding_from_api(word, model='vicuna-7b-v1.5'):\n", + " url = 'http://127.0.0.1:8000/v1/embeddings'\n", + " headers = {'Content-Type': 'application/json'}\n", + " data = json.dumps({\n", + " 'model': model,\n", + " 'input': word\n", + " })\n", + "\n", + " response = requests.post(url, headers=headers, data=data)\n", + " if response.status_code == 200:\n", + " embedding = np.array(response.json()['data'][0]['embedding'])\n", + " return embedding\n", + " else:\n", + " print(f\"Error: {response.status_code} - {response.text}\")\n", + " return None\n", + "\n", + "\n", + "def cosine_similarity(vec1, vec2):\n", + " return 1 - cosine(vec1, vec2)\n", + "\n", + "\n", + "def print_cosine_similarity(embeddings, texts):\n", + " for i in range(len(texts)):\n", + " for j in range(i + 1, len(texts)):\n", + " sim = cosine_similarity(embeddings[texts[i]], embeddings[texts[j]])\n", + " print(f\"Cosine similarity between '{texts[i]}' and '{texts[j]}': {sim:.2f}\")\n", + "\n", + "\n", + "texts = [\n", + " 'The quick brown fox',\n", + " 'The quick brown dog',\n", + " 'The fast brown fox',\n", + " 'A completely different sentence'\n", + "]\n", + "\n", + "embeddings = {}\n", + "for text in texts:\n", + " embeddings[text] = get_embedding_from_api(text)\n", + "\n", + "print_cosine_similarity(embeddings, texts)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + 
"nbformat_minor": 0 +} diff --git a/playground/deepspeed_config_s2.json b/playground/deepspeed_config_s2.json new file mode 100644 index 0000000000000000000000000000000000000000..4ade15504251620b4034ac1cb13f38c645949bda --- /dev/null +++ b/playground/deepspeed_config_s2.json @@ -0,0 +1,15 @@ +{ + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "cpu" + }, + "contiguous_gradients": true, + "overlap_comm": true + }, + "fp16": { + "enabled": "auto" + }, + "train_micro_batch_size_per_gpu": "auto", + "gradient_accumulation_steps": "auto" +} \ No newline at end of file diff --git a/playground/deepspeed_config_s3.json b/playground/deepspeed_config_s3.json new file mode 100644 index 0000000000000000000000000000000000000000..07f4b16a66f112fb485e2c327dda812e5d33723a --- /dev/null +++ b/playground/deepspeed_config_s3.json @@ -0,0 +1,32 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + }, + "offload_param": { + "device": "cpu", + "pin_memory": true + }, + "overlap_comm": true, + "contiguous_gradients": true, + "stage3_max_live_parameters" : 1e9, + "stage3_max_reuse_distance" : 1e9, + "stage3_prefetch_bucket_size" : 5e8, + "stage3_param_persistence_threshold" : 1e6, + "sub_group_size" : 1e12, + "stage3_gather_16bit_weights_on_model_save": true + }, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "gradient_accumulation_steps": "auto" +} \ No newline at end of file diff --git a/playground/test_embedding/README.md b/playground/test_embedding/README.md new file mode 100644 index 0000000000000000000000000000000000000000..57ac73c59f4a5c30f6c7a429debab12b9e7a1d7f --- /dev/null +++ b/playground/test_embedding/README.md @@ -0,0 +1,15 @@ +## Machine Learning with Embeddings +You can use embeddings to +- Evaluate text similarity, see [test_sentence_similarity.py](test_sentence_similarity.py) +- Build your own classifier, see [test_classification.py](test_classification.py) +- Search relative texts, see [test_semantic_search.py](test_semantic_search.py) + +To these tests, you need to download the data [here](https://www.kaggle.com/datasets/snap/amazon-fine-food-reviews). You also need an OpenAI API key for comparison. + +Run with: +```bash +cd playground/test_embedding +python3 test_classification.py +``` + +The script will train classifiers based on `vicuna-7b`, `text-similarity-ada-001` and `text-embedding-ada-002` and report the accuracy of each classifier. 
diff --git a/playground/test_embedding/test_classification.py b/playground/test_embedding/test_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..393827bb47ebfad2d38d04d58aadc71f7ad7d407 --- /dev/null +++ b/playground/test_embedding/test_classification.py @@ -0,0 +1,83 @@ +import json +import os + +import numpy as np +import openai +import pandas as pd +import requests +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import train_test_split +from sklearn.metrics import classification_report, accuracy_score + + +np.set_printoptions(threshold=10000) + + +def get_embedding_from_api(word, model="vicuna-7b-v1.1"): + if "ada" in model: + resp = openai.Embedding.create( + model=model, + input=word, + ) + embedding = np.array(resp["data"][0]["embedding"]) + return embedding + + url = "http://localhost:8000/v1/embeddings" + headers = {"Content-Type": "application/json"} + data = json.dumps({"model": model, "input": word}) + + response = requests.post(url, headers=headers, data=data) + if response.status_code == 200: + embedding = np.array(response.json()["data"][0]["embedding"]) + return embedding + else: + print(f"Error: {response.status_code} - {response.text}") + return None + + +def create_embedding_data_frame(data_path, model, max_tokens=500): + df = pd.read_csv(data_path, index_col=0) + df = df[["Time", "ProductId", "UserId", "Score", "Summary", "Text"]] + df = df.dropna() + df["combined"] = ( + "Title: " + df.Summary.str.strip() + "; Content: " + df.Text.str.strip() + ) + top_n = 1000 + df = df.sort_values("Time").tail(top_n * 2) + df.drop("Time", axis=1, inplace=True) + + df["n_tokens"] = df.combined.apply(lambda x: len(x)) + df = df[df.n_tokens <= max_tokens].tail(top_n) + df["embedding"] = df.combined.apply(lambda x: get_embedding_from_api(x, model)) + return df + + +def train_random_forest(df): + X_train, X_test, y_train, y_test = train_test_split( + list(df.embedding.values), df.Score, test_size=0.2, random_state=42 + ) + + clf = RandomForestClassifier(n_estimators=100) + clf.fit(X_train, y_train) + preds = clf.predict(X_test) + + report = classification_report(y_test, preds) + accuracy = accuracy_score(y_test, preds) + return clf, accuracy, report + + +input_datapath = "amazon_fine_food_review.csv" +if not os.path.exists(input_datapath): + raise Exception( + f"Please download data from: https://www.kaggle.com/datasets/snap/amazon-fine-food-reviews" + ) + +df = create_embedding_data_frame(input_datapath, "vicuna-7b-v1.1") +clf, accuracy, report = train_random_forest(df) +print(f"Vicuna-7b-v1.1 accuracy:{accuracy}") +df = create_embedding_data_frame(input_datapath, "text-similarity-ada-001") +clf, accuracy, report = train_random_forest(df) +print(f"text-similarity-ada-001 accuracy:{accuracy}") +df = create_embedding_data_frame(input_datapath, "text-embedding-ada-002") +clf, accuracy, report = train_random_forest(df) +print(f"text-embedding-ada-002 accuracy:{accuracy}") diff --git a/playground/test_embedding/test_semantic_search.py b/playground/test_embedding/test_semantic_search.py new file mode 100644 index 0000000000000000000000000000000000000000..879b240b626d9bb87f739cfe749822a17777efab --- /dev/null +++ b/playground/test_embedding/test_semantic_search.py @@ -0,0 +1,99 @@ +import json +import os + +import numpy as np +import openai +import pandas as pd +import requests +from scipy.spatial.distance import cosine + + +def cosine_similarity(vec1, vec2): + try: + return 1 - cosine(vec1, vec2) + except: + 
print(vec1.shape, vec2.shape) + + +def get_embedding_from_api(word, model="vicuna-7b-v1.1"): + if "ada" in model: + resp = openai.Embedding.create( + model=model, + input=word, + ) + embedding = np.array(resp["data"][0]["embedding"]) + return embedding + + url = "http://localhost:8000/v1/embeddings" + headers = {"Content-Type": "application/json"} + data = json.dumps({"model": model, "input": word}) + + response = requests.post(url, headers=headers, data=data) + if response.status_code == 200: + embedding = np.array(response.json()["data"][0]["embedding"]) + return embedding + else: + print(f"Error: {response.status_code} - {response.text}") + return None + + +def create_embedding_data_frame(data_path, model, max_tokens=500): + df = pd.read_csv(data_path, index_col=0) + df = df[["Time", "ProductId", "UserId", "Score", "Summary", "Text"]] + df = df.dropna() + df["combined"] = ( + "Title: " + df.Summary.str.strip() + "; Content: " + df.Text.str.strip() + ) + top_n = 1000 + df = df.sort_values("Time").tail(top_n * 2) + df.drop("Time", axis=1, inplace=True) + + df["n_tokens"] = df.combined.apply(lambda x: len(x)) + df = df[df.n_tokens <= max_tokens].tail(top_n) + df["embedding"] = df.combined.apply(lambda x: get_embedding_from_api(x, model)) + return df + + +def search_reviews(df, product_description, n=3, pprint=False, model="vicuna-7b-v1.1"): + product_embedding = get_embedding_from_api(product_description, model=model) + df["similarity"] = df.embedding.apply( + lambda x: cosine_similarity(x, product_embedding) + ) + + results = ( + df.sort_values("similarity", ascending=False) + .head(n) + .combined.str.replace("Title: ", "") + .str.replace("; Content:", ": ") + ) + if pprint: + for r in results: + print(r[:200]) + print() + return results + + +def print_model_search(input_path, model): + print(f"Model: {model}") + df = create_embedding_data_frame(input_path, model) + print("search: delicious beans") + results = search_reviews(df, "delicious beans", n=5, model=model) + print(results) + print("search: whole wheat pasta") + results = search_reviews(df, "whole wheat pasta", n=5, model=model) + print(results) + print("search: bad delivery") + results = search_reviews(df, "bad delivery", n=5, model=model) + print(results) + + +input_datapath = "amazon_fine_food_review.csv" +if not os.path.exists(input_datapath): + raise Exception( + f"Please download data from: https://www.kaggle.com/datasets/snap/amazon-fine-food-reviews" + ) + + +print_model_search(input_datapath, "vicuna-7b-v1.1") +print_model_search(input_datapath, "text-similarity-ada-001") +print_model_search(input_datapath, "text-embedding-ada-002") diff --git a/playground/test_embedding/test_sentence_similarity.py b/playground/test_embedding/test_sentence_similarity.py new file mode 100644 index 0000000000000000000000000000000000000000..d7a8f6e5f5ac69b518019d12c8fc8068fd0a7232 --- /dev/null +++ b/playground/test_embedding/test_sentence_similarity.py @@ -0,0 +1,67 @@ +import json +import os + +import numpy as np +import openai +import requests +from scipy.spatial.distance import cosine + + +def get_embedding_from_api(word, model="vicuna-7b-v1.5"): + if "ada" in model: + resp = openai.Embedding.create( + model=model, + input=word, + ) + embedding = np.array(resp["data"][0]["embedding"]) + return embedding + + url = "http://localhost:8000/v1/embeddings" + headers = {"Content-Type": "application/json"} + data = json.dumps({"model": model, "input": word}) + + response = requests.post(url, headers=headers, data=data) + if response.status_code 
== 200: + embedding = np.array(response.json()["data"][0]["embedding"]) + return embedding + else: + print(f"Error: {response.status_code} - {response.text}") + return None + + +def cosine_similarity(vec1, vec2): + return 1 - cosine(vec1, vec2) + + +def print_cosine_similarity(embeddings, texts): + for i in range(len(texts)): + for j in range(i + 1, len(texts)): + sim = cosine_similarity(embeddings[texts[i]], embeddings[texts[j]]) + print(f"Cosine similarity between '{texts[i]}' and '{texts[j]}': {sim:.2f}") + + +texts = [ + "The quick brown fox", + "The quick brown dog", + "The fast brown fox", + "A completely different sentence", +] + +embeddings = {} +for text in texts: + embeddings[text] = get_embedding_from_api(text) + +print("Vicuna-7B:") +print_cosine_similarity(embeddings, texts) + +for text in texts: + embeddings[text] = get_embedding_from_api(text, model="text-similarity-ada-001") + +print("text-similarity-ada-001:") +print_cosine_similarity(embeddings, texts) + +for text in texts: + embeddings[text] = get_embedding_from_api(text, model="text-embedding-ada-002") + +print("text-embedding-ada-002:") +print_cosine_similarity(embeddings, texts) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..fedd9e2dc1165695950e8050b76e4c139534c390 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "fschat" +version = "0.2.36" +description = "An open platform for training, serving, and evaluating large language model based chatbots." +readme = "README.md" +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", +] +dependencies = [ + "aiohttp", "fastapi", "httpx", "markdown2[all]", "nh3", "numpy", + "prompt_toolkit>=3.0.0", "pydantic<3,>=2.0.0", "pydantic-settings", "psutil", "requests", "rich>=10.0.0", + "shortuuid", "tiktoken", "uvicorn", +] + +[project.optional-dependencies] +model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf"] +webui = ["gradio>=4.10"] +train = ["einops", "flash-attn>=2.0", "wandb"] +llm_judge = ["openai<1", "anthropic>=0.3", "ray"] +dev = ["black==23.3.0", "pylint==2.8.2"] + +[project.urls] +"Homepage" = "https://github.com/lm-sys/fastchat" +"Bug Tracker" = "https://github.com/lm-sys/fastchat/issues" + +[tool.setuptools.packages.find] +exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] + +[tool.wheel] +exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] diff --git a/scripts/build-api.sh b/scripts/build-api.sh new file mode 100644 index 0000000000000000000000000000000000000000..8198108e08b23c0dcd805bcf30769ea6e4e6fa5f --- /dev/null +++ b/scripts/build-api.sh @@ -0,0 +1,60 @@ +#!/bin/bash +# A rather convenient script for spinning up models behind screens + + +# Variables +PROJECT_DIR="$(pwd)" +CONDA_ENV_NAME="fastchat" # + +MODEL_PATH="HuggingFaceH4/zephyr-7b-beta" #beta is better than the alpha version, base model w/o quantization +MODEL_PATH="lmsys/vicuna-7b-v1.5" + +API_HOST="0.0.0.0" +API_PORT_NUMBER=8000 + + +# init the screens +check_and_create_screen() { + local SCREENNAME="$1" + if screen -list | grep -q "$SCREENNAME"; then + echo "Screen session '$SCREENNAME' exists. Doing nothing." + else + echo "Screen session '$SCREENNAME' not found. Creating..." 
+        screen -d -m -S "$SCREENNAME"
+        echo "created!"
+    fi
+}
+
+# convenience function for sending commands to named screens
+send_cmd() {
+    local SCREENNAME="$1"
+    local CMD="$2"
+    screen -DRRS "$SCREENNAME" -X stuff "$CMD \r"
+}
+
+# hardcoded names, for baby api
+SCREENNAMES=(
+    "controller"
+    "api"
+    # Worker screens include the devices they are bound to; if 'd0' is the only worker it has full GPU access
+    "worker-d0"
+    "worker-d1"
+)
+
+for screen in "${SCREENNAMES[@]}"; do
+    check_and_create_screen "$screen"
+    sleep 0.1
+    # also activate the conda compute environment for these
+    screen -DRRS "$screen" -X stuff "conda deactivate \r"
+    screen -DRRS "$screen" -X stuff "conda activate $CONDA_ENV_NAME \r"
+
+done
+
+
+# Send Commands on a per Screen Basis
+screen -DRRS controller -X stuff "python3 -m fastchat.serve.controller \r"
+
+screen -DRRS worker-d0 -X stuff "CUDA_VISIBLE_DEVICES=0 python3 -m fastchat.serve.model_worker --model-path $MODEL_PATH --conv-template one_shot --limit-worker-concurrency 1 \r"
+screen -DRRS worker-d1 -X stuff "CUDA_VISIBLE_DEVICES=1 python3 -m fastchat.serve.model_worker --model-path $MODEL_PATH --port 21003 --worker-address http://localhost:21003 --conv-template one_shot --limit-worker-concurrency 1 \r"
+
+screen -DRRS api -X stuff "python3 -m fastchat.serve.openai_api_server --host $API_HOST --port $API_PORT_NUMBER \r"
diff --git a/scripts/test_readme_train.sh b/scripts/test_readme_train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1d1a011396c5298aed26f887418e54fd79a1ace4
--- /dev/null
+++ b/scripts/test_readme_train.sh
@@ -0,0 +1,24 @@
+torchrun --nproc_per_node=4 --master_port=20001 fastchat/train/train_mem.py \
+    --model_name_or_path meta-llama/Llama-2-7b-hf \
+    --data_path data/dummy_conversation.json \
+    --bf16 True \
+    --output_dir output_vicuna \
+    --num_train_epochs 3 \
+    --per_device_train_batch_size 2 \
+    --per_device_eval_batch_size 2 \
+    --gradient_accumulation_steps 16 \
+    --evaluation_strategy "no" \
+    --save_strategy "steps" \
+    --save_steps 1200 \
+    --save_total_limit 10 \
+    --learning_rate 2e-5 \
+    --weight_decay 0. \
+    --warmup_ratio 0.03 \
+    --lr_scheduler_type "cosine" \
+    --logging_steps 1 \
+    --fsdp "full_shard auto_wrap" \
+    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
+    --tf32 True \
+    --model_max_length 2048 \
+    --gradient_checkpointing True \
+    --lazy_preprocess True
diff --git a/scripts/train_lora.sh b/scripts/train_lora.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d30caad41581244459996a31d814040cd0adb550
--- /dev/null
+++ b/scripts/train_lora.sh
@@ -0,0 +1,29 @@
+deepspeed fastchat/train/train_lora.py \
+    --model_name_or_path lmsys/vicuna-7b-v1.5 \
+    --lora_r 8 \
+    --lora_alpha 16 \
+    --lora_dropout 0.05 \
+    --data_path $DATA_PATH \
+    --output_dir ./checkpoints \
+    --num_train_epochs 150 \
+    --fp16 True \
+    --per_device_train_batch_size 2 \
+    --per_device_eval_batch_size 2 \
+    --gradient_accumulation_steps 1 \
+    --evaluation_strategy "steps" \
+    --eval_steps 100 \
+    --save_strategy "steps" \
+    --save_steps 200 \
+    --save_total_limit 2 \
+    --learning_rate 2e-5 \
+    --weight_decay 0.
\ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_strategy "steps" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --q_lora False \ + --deepspeed $PATH_TO_DEEPSPEED_CONFIG \ + --gradient_checkpointing True \ + --flash_attn False diff --git a/scripts/train_vicuna_13b.sh b/scripts/train_vicuna_13b.sh new file mode 100644 index 0000000000000000000000000000000000000000..a6a843d3734e0a0127a9c54ad862f3992c10ac33 --- /dev/null +++ b/scripts/train_vicuna_13b.sh @@ -0,0 +1,26 @@ +torchrun --nproc_per_node=8 --master_port=20001 fastchat/train/train_mem.py \ + --model_name_or_path ~/model_weights/llama-13b \ + --data_path ~/datasets/sharegpt_20230422_clean_lang_split_identity.json \ + --bf16 True \ + --output_dir output_vicuna_13b \ + --num_train_epochs 3 \ + --per_device_train_batch_size 4 \ + --per_device_eval_batch_size 32 \ + --gradient_accumulation_steps 4 \ + --evaluation_strategy "steps" \ + --eval_steps 1500 \ + --save_strategy "steps" \ + --save_steps 1500 \ + --save_total_limit 8 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.04 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --fsdp "full_shard auto_wrap offload" \ + --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --lazy_preprocess True + diff --git a/scripts/train_vicuna_7b.sh b/scripts/train_vicuna_7b.sh new file mode 100644 index 0000000000000000000000000000000000000000..8d1a45ac50934c25ee8d616e270a8805103fbaf2 --- /dev/null +++ b/scripts/train_vicuna_7b.sh @@ -0,0 +1,26 @@ +torchrun --nproc_per_node=4 --master_port=20001 fastchat/train/train_mem.py \ + --model_name_or_path ~/model_weights/llama-7b \ + --data_path ~/datasets/sharegpt_20230422_clean_lang_split_identity.json \ + --bf16 True \ + --output_dir output_vicuna_7b \ + --num_train_epochs 3 \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 16 \ + --gradient_accumulation_steps 16 \ + --evaluation_strategy "steps" \ + --eval_steps 1500 \ + --save_strategy "steps" \ + --save_steps 1500 \ + --save_total_limit 8 \ + --learning_rate 2e-5 \ + --weight_decay 0. 
\ + --warmup_ratio 0.04 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --fsdp "full_shard auto_wrap" \ + --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --lazy_preprocess True + diff --git a/scripts/upload_pypi.sh b/scripts/upload_pypi.sh new file mode 100644 index 0000000000000000000000000000000000000000..b0da77ef2e3707caf1d5df00067f8d44f80c81b3 --- /dev/null +++ b/scripts/upload_pypi.sh @@ -0,0 +1,3 @@ +rm -rf dist +python3 -m build +python3 -m twine upload dist/* diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3d1c1e61cf0287b57512b0c6eb3f2dda9572b964 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,53 @@ +## Unit tests for FastChat + +### Test CLI Inference + +``` +python3 test_cli.py +``` + +### Test OpenAI API Server + +``` +python3 launch_openai_api_test_server.py +``` + +``` +python3 test_openai_api.py +``` + +### Test GUI Serving + +``` +python3 -m fastchat.serve.controller +``` + +``` +CUDA_VISIBLE_DEVICES=0,1 python3 -m fastchat.serve.model_worker --model-path ~/model_weights/koala-13b --num-gpus 2 --port 30000 --worker http://localhost:30000 +CUDA_VISIBLE_DEVICES=2,3 python3 -m fastchat.serve.model_worker --model-path ~/model_weights/alpaca-13b --num-gpus 2 --port 30002 --worker http://localhost:30002 +CUDA_VISIBLE_DEVICES=4,5 python3 -m fastchat.serve.model_worker --model-path ~/model_weights/vicuna-13b --port 30004 --worker http://localhost:30004 --num-gpus 2 +CUDA_VISIBLE_DEVICES=6,7 python3 -m fastchat.serve.model_worker --model-path OpenAssistant/oasst-sft-1-pythia-12b --port 30006 --worker http://localhost:30006 --num-gpus 2 + +CUDA_VISIBLE_DEVICES=0,1 python3 -m fastchat.serve.model_worker --model-path StabilityAI/stablelm-tuned-alpha-7b --num-gpus 2 --port 30000 --worker http://localhost:30000 +CUDA_VISIBLE_DEVICES=2,3 python3 -m fastchat.serve.model_worker --model-path databricks/dolly-v2-12b --num-gpus 2 --port 30002 --worker http://localhost:30002 +CUDA_VISIBLE_DEVICES=4 python3 -m fastchat.serve.model_worker --model-path THUDM/chatglm-6b --port 30004 --worker http://localhost:30004 +CUDA_VISIBLE_DEVICES=5 python3 -m fastchat.serve.model_worker --model-path lmsys/fastchat-t5-3b-v1.0 --port 30005 --worker http://localhost:30005 +CUDA_VISIBLE_DEVICES=6 python3 -m fastchat.serve.model_worker --model-path ~/model_weights/baize-7b --port 30006 --worker http://localhost:30006 +CUDA_VISIBLE_DEVICES=7 python3 -m fastchat.serve.model_worker --model-path ~/model_weights/RWKV-4-Raven-7B-v11x-Eng99%-Other1%-20230429-ctx8192.pth --port 30007 --worker http://localhost:30007 +``` + +``` +python3 -m fastchat.serve.gradio_web_server_multi +``` + +### Test Peft Serving + +``` +python3 -m fastchat.serve.controller +``` + +``` +PEFT_SHARE_BASE_WEIGHTS=true python3 -m fastchat.serve.multi_model_worker \ + --model-path SurfaceData/dummy_pythia160m_lora16_peft_chat \ + --model-path SurfaceData/dummy_pythia160m_lora8_peft_chat +``` diff --git a/tests/killall_python.sh b/tests/killall_python.sh new file mode 100644 index 0000000000000000000000000000000000000000..52eb874ac63799eaa33dd61acf9d8f19ddf768ef --- /dev/null +++ b/tests/killall_python.sh @@ -0,0 +1 @@ +kill -9 $(ps aux | grep 'python' | grep 'fastchat' | grep -v 'grep' | awk '{print $2}') diff --git a/tests/launch_openai_api_test_server.py b/tests/launch_openai_api_test_server.py new file mode 100644 index 
0000000000000000000000000000000000000000..e3461a935fe2f16ef2a6377b0424134ea0948043 --- /dev/null +++ b/tests/launch_openai_api_test_server.py @@ -0,0 +1,49 @@ +""" +Launch an OpenAI API server with multiple model workers. +""" +import os +import argparse + + +def launch_process(cmd): + os.popen(cmd) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--multimodal", action="store_true", default=False) + args = parser.parse_args() + + launch_process("python3 -m fastchat.serve.controller") + launch_process("python3 -m fastchat.serve.openai_api_server") + + if args.multimodal: + models = [ + ("liuhaotian/llava-v1.5-7b", "sglang_worker"), + ] + else: + models = [ + ("lmsys/vicuna-7b-v1.5", "model_worker"), + ("lmsys/fastchat-t5-3b-v1.0", "model_worker"), + ("THUDM/chatglm-6b", "model_worker"), + ("mosaicml/mpt-7b-chat", "model_worker"), + ("meta-llama/Llama-2-7b-chat-hf", "vllm_worker"), + ] + + for i, (model_path, worker_name) in enumerate(models): + cmd = ( + f"CUDA_VISIBLE_DEVICES={i} python3 -m fastchat.serve.{worker_name} " + f"--model-path {model_path} --port {40000+i} " + f"--worker-address http://localhost:{40000+i} " + ) + + if "llava" in model_path.lower(): + cmd += f"--tokenizer-path llava-hf/llava-1.5-7b-hf" + + if worker_name == "vllm_worker": + cmd += "--tokenizer hf-internal-testing/llama-tokenizer" + + launch_process(cmd) + + while True: + pass diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..113e497a48062bc68039357e97194fdc8d5853d9 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,88 @@ +"""Test command line interface for model inference.""" +import argparse +import os + +from fastchat.utils import run_cmd + + +def test_single_gpu(): + models = [ + "lmsys/vicuna-7b-v1.5", + "lmsys/longchat-7b-16k", + "lmsys/fastchat-t5-3b-v1.0", + "meta-llama/Llama-2-7b-chat-hf", + "THUDM/chatglm-6b", + "THUDM/chatglm2-6b", + "mosaicml/mpt-7b-chat", + "tiiuae/falcon-7b-instruct", + "~/model_weights/alpaca-7b", + "~/model_weights/RWKV-4-Raven-7B-v11x-Eng99%-Other1%-20230429-ctx8192.pth", + ] + + for model_path in models: + if "model_weights" in model_path and not os.path.exists( + os.path.expanduser(model_path) + ): + continue + cmd = ( + f"python3 -m fastchat.serve.cli --model-path {model_path} " + f"--style programmatic < test_cli_inputs.txt" + ) + ret = run_cmd(cmd) + if ret != 0: + return + + print("") + + +def test_multi_gpu(): + models = [ + "lmsys/vicuna-13b-v1.3", + ] + + for model_path in models: + cmd = ( + f"python3 -m fastchat.serve.cli --model-path {model_path} " + f"--style programmatic --num-gpus 2 --max-gpu-memory 14Gib < test_cli_inputs.txt" + ) + ret = run_cmd(cmd) + if ret != 0: + return + print("") + + +def test_8bit(): + models = [ + "lmsys/vicuna-13b-v1.3", + ] + + for model_path in models: + cmd = ( + f"python3 -m fastchat.serve.cli --model-path {model_path} " + f"--style programmatic --load-8bit < test_cli_inputs.txt" + ) + ret = run_cmd(cmd) + if ret != 0: + return + print("") + + +def test_hf_api(): + models = [ + "lmsys/vicuna-7b-v1.5", + "lmsys/fastchat-t5-3b-v1.0", + ] + + for model_path in models: + cmd = f"python3 -m fastchat.serve.huggingface_api --model-path {model_path}" + ret = run_cmd(cmd) + if ret != 0: + return + print("") + + +if __name__ == "__main__": + test_single_gpu() + test_multi_gpu() + test_8bit() + test_hf_api() diff --git a/tests/test_cli_inputs.txt b/tests/test_cli_inputs.txt new file mode 100644 index 
0000000000000000000000000000000000000000..df79f87e114662266a23bda9e6803271c7fdfa7a --- /dev/null +++ b/tests/test_cli_inputs.txt @@ -0,0 +1,4 @@ +Who are you? __END_OF_A_MESSAGE_47582648__ +Three tips for staying healthy. __END_OF_A_MESSAGE_47582648__ +One more tip. __END_OF_A_MESSAGE_47582648__ +!!exit __END_OF_A_MESSAGE_47582648__ diff --git a/tests/test_image_utils.py b/tests/test_image_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..53b1bfe7c49b7c25ca87ef180d80cc22b08c6448 --- /dev/null +++ b/tests/test_image_utils.py @@ -0,0 +1,131 @@ +""" +Usage: +python3 -m unittest tests.test_image_utils +""" + +import base64 +from io import BytesIO +import os +import unittest + +import numpy as np +from PIL import Image + +from fastchat.utils import ( + resize_image_and_return_image_in_bytes, + image_moderation_filter, +) +from fastchat.conversation import get_conv_template + + +def check_byte_size_in_mb(image_base64_str): + return len(image_base64_str) / 1024 / 1024 + + +def generate_random_image(target_size_mb, image_format="PNG"): + # Convert target size from MB to bytes + target_size_bytes = target_size_mb * 1024 * 1024 + + # Estimate dimensions + dimension = int((target_size_bytes / 3) ** 0.5) + + # Generate random pixel data + pixel_data = np.random.randint(0, 256, (dimension, dimension, 3), dtype=np.uint8) + + # Create an image from the pixel data + img = Image.fromarray(pixel_data) + + # Save image to a temporary file + temp_filename = "temp_image." + image_format.lower() + img.save(temp_filename, format=image_format) + + # Check the file size and adjust quality if needed + while os.path.getsize(temp_filename) < target_size_bytes: + # Increase dimensions or change compression quality + dimension += 1 + pixel_data = np.random.randint( + 0, 256, (dimension, dimension, 3), dtype=np.uint8 + ) + img = Image.fromarray(pixel_data) + img.save(temp_filename, format=image_format) + + return img + + +class DontResizeIfLessThanMaxTest(unittest.TestCase): + def test_dont_resize_if_less_than_max(self): + max_image_size = 5 + initial_size_mb = 0.1 # Initial image size + img = generate_random_image(initial_size_mb) + + image_bytes = BytesIO() + img.save(image_bytes, format="PNG") # Save the image as JPEG + previous_image_size = check_byte_size_in_mb(image_bytes.getvalue()) + + image_bytes = resize_image_and_return_image_in_bytes( + img, max_image_size_mb=max_image_size + ) + new_image_size = check_byte_size_in_mb(image_bytes.getvalue()) + + self.assertEqual(previous_image_size, new_image_size) + + +class ResizeLargeImageForModerationEndpoint(unittest.TestCase): + def test_resize_large_image_and_send_to_moderation_filter(self): + initial_size_mb = 6 # Initial image size which we know is greater than what the endpoint can take + img = generate_random_image(initial_size_mb) + + nsfw_flag, csam_flag = image_moderation_filter(img) + self.assertFalse(nsfw_flag) + self.assertFalse(nsfw_flag) + + +class DontResizeIfMaxImageSizeIsNone(unittest.TestCase): + def test_dont_resize_if_max_image_size_is_none(self): + initial_size_mb = 0.2 # Initial image size + img = generate_random_image(initial_size_mb) + + image_bytes = BytesIO() + img.save(image_bytes, format="PNG") # Save the image as JPEG + previous_image_size = check_byte_size_in_mb(image_bytes.getvalue()) + + image_bytes = resize_image_and_return_image_in_bytes( + img, max_image_size_mb=None + ) + new_image_size = check_byte_size_in_mb(image_bytes.getvalue()) + + self.assertEqual(previous_image_size, new_image_size) + + +class 
OpenAIConversationDontResizeImage(unittest.TestCase): + def test(self): + conv = get_conv_template("chatgpt") + initial_size_mb = 0.2 # Initial image size + img = generate_random_image(initial_size_mb) + image_bytes = BytesIO() + img.save(image_bytes, format="PNG") # Save the image as JPEG + previous_image_size = check_byte_size_in_mb(image_bytes.getvalue()) + + resized_img = conv.convert_image_to_base64(img) + resized_img_bytes = base64.b64decode(resized_img) + new_image_size = check_byte_size_in_mb(resized_img_bytes) + + self.assertEqual(previous_image_size, new_image_size) + + +class ClaudeConversationResizesCorrectly(unittest.TestCase): + def test(self): + conv = get_conv_template("claude-3-haiku-20240307") + initial_size_mb = 5 # Initial image size + img = generate_random_image(initial_size_mb) + image_bytes = BytesIO() + img.save(image_bytes, format="PNG") # Save the image as JPEG + previous_image_size = check_byte_size_in_mb(image_bytes.getvalue()) + + resized_img = conv.convert_image_to_base64(img) + new_base64_image_size = check_byte_size_in_mb(resized_img) + new_image_bytes_size = check_byte_size_in_mb(base64.b64decode(resized_img)) + + self.assertLess(new_image_bytes_size, previous_image_size) + self.assertLessEqual(new_image_bytes_size, conv.max_image_size_mb) + self.assertLessEqual(new_base64_image_size, 5) diff --git a/tests/test_openai_api.py b/tests/test_openai_api.py new file mode 100644 index 0000000000000000000000000000000000000000..3e7879bb5e499b080a9e50c5f5db570d7378f0dc --- /dev/null +++ b/tests/test_openai_api.py @@ -0,0 +1,149 @@ +""" +Test the OpenAI compatible server + +Launch: +python3 launch_openai_api_test_server.py +""" +import warnings + +import openai +from fastchat.utils import run_cmd + + +openai.api_key = "EMPTY" # Not support yet +openai.base_url = "http://localhost:8000/v1/" + + +def test_list_models(): + model_list = openai.models.list() + names = [x.id for x in model_list.data] + return names + + +def test_completion(model, logprob): + prompt = "Once upon a time" + completion = openai.completions.create( + model=model, + prompt=prompt, + logprobs=logprob, + max_tokens=64, + temperature=0, + ) + + print(f"full text: {prompt + completion.choices[0].text}", flush=True) + if completion.choices[0].logprobs is not None: + print( + f"logprobs: {completion.choices[0].logprobs.token_logprobs[:10]}", + flush=True, + ) + + +def test_completion_stream(model): + prompt = "Once upon a time" + res = openai.completions.create( + model=model, + prompt=prompt, + max_tokens=64, + stream=True, + temperature=0, + ) + print(prompt, end="") + for chunk in res: + content = chunk.choices[0].text + print(content, end="", flush=True) + print() + + +def test_embedding(model): + embedding = openai.embeddings.create(model=model, input="Hello world!") + print(f"embedding len: {len(embedding.data[0].embedding)}") + print(f"embedding value[:5]: {embedding.data[0].embedding[:5]}") + + +def test_chat_completion(model): + completion = openai.chat.completions.create( + model=model, + messages=[{"role": "user", "content": "Hello! What is your name?"}], + temperature=0, + ) + print(completion.choices[0].message.content) + + +def test_chat_completion_stream(model): + messages = [{"role": "user", "content": "Hello! 
What is your name?"}] + res = openai.chat.completions.create( + model=model, messages=messages, stream=True, temperature=0 + ) + for chunk in res: + try: + content = chunk.choices[0].delta.content + if content is None: + content = "" + except Exception as e: + content = chunk.choices[0].delta.get("content", "") + print(content, end="", flush=True) + print() + + +def test_openai_curl(): + run_cmd("curl http://localhost:8000/v1/models") + + run_cmd( + """ +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "vicuna-7b-v1.5", + "messages": [{"role": "user", "content": "Hello! What is your name?"}] + }' +""" + ) + + run_cmd( + """ +curl http://localhost:8000/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "vicuna-7b-v1.5", + "prompt": "Once upon a time", + "max_tokens": 41, + "temperature": 0.5 + }' +""" + ) + + run_cmd( + """ +curl http://localhost:8000/v1/embeddings \ + -H "Content-Type: application/json" \ + -d '{ + "model": "vicuna-7b-v1.5", + "input": "Hello world!" + }' +""" + ) + + +if __name__ == "__main__": + models = test_list_models() + print(f"models: {models}") + + for model in models: + print(f"===== Test {model} ======") + + if model in ["fastchat-t5-3b-v1.0"]: + logprob = None + else: + logprob = 1 + + test_completion(model, logprob) + test_completion_stream(model) + test_chat_completion(model) + test_chat_completion_stream(model) + try: + test_embedding(model) + except openai.APIError as e: + print(f"Embedding error: {e}") + + print("===== Test curl =====") + test_openai_curl() diff --git a/tests/test_openai_langchain.py b/tests/test_openai_langchain.py new file mode 100644 index 0000000000000000000000000000000000000000..b9c07fcf6e00b4c50629b4aac1e1ccc31974a644 --- /dev/null +++ b/tests/test_openai_langchain.py @@ -0,0 +1,39 @@ +# Usage: +# python3 -m fastchat.serve.model_worker --model-path lmsys/vicuna-7b-v1.5 --model-names gpt-3.5-turbo,text-davinci-003,text-embedding-ada-002 +# export OPENAI_API_BASE=http://localhost:8000/v1 +# export OPENAI_API_KEY=EMPTY +# wget https://raw.githubusercontent.com/hwchase17/langchain/v0.0.200/docs/modules/state_of_the_union.txt + +import os + +from langchain.chat_models import ChatOpenAI +from langchain.document_loaders import TextLoader +from langchain.embeddings import OpenAIEmbeddings +from langchain.indexes import VectorstoreIndexCreator + + +def test_chain(): + embedding = OpenAIEmbeddings(model="text-embedding-ada-002") + loader = TextLoader("state_of_the_union.txt") + index = VectorstoreIndexCreator(embedding=embedding).from_loaders([loader]) + + llm = ChatOpenAI(model="gpt-3.5-turbo") + + questions = [ + "Who is the speaker", + "What did the president say about Ketanji Brown Jackson", + "What are the threats to America", + "Who are mentioned in the speech", + "Who is the vice president", + "How many projects were announced", + ] + + for query in questions: + print("Query:", query) + print("Answer:", index.query(query, llm=llm)) + + +if __name__ == "__main__": + os.environ["OPENAI_API_BASE"] = "http://localhost:8000/v1" + os.environ["OPENAI_API_KEY"] = "empty" + test_chain() diff --git a/tests/test_openai_vision_api.py b/tests/test_openai_vision_api.py new file mode 100644 index 0000000000000000000000000000000000000000..2f089c418556fc68e3a791c62727fd945c8b9be9 --- /dev/null +++ b/tests/test_openai_vision_api.py @@ -0,0 +1,162 @@ +""" +Test the OpenAI compatible server + +Launch: +python3 launch_openai_api_test_server.py --multimodal +""" + +import 
openai + +from fastchat.utils import run_cmd + +openai.api_key = "EMPTY" # Not support yet +openai.base_url = "http://localhost:8000/v1/" + + +def encode_image(image): + import base64 + from io import BytesIO + import requests + + from PIL import Image + + if image.startswith("http://") or image.startswith("https://"): + response = requests.get(image) + image = Image.open(BytesIO(response.content)).convert("RGB") + else: + image = Image.open(image).convert("RGB") + + buffered = BytesIO() + image.save(buffered, format="PNG") + img_b64_str = base64.b64encode(buffered.getvalue()).decode("utf-8") + + return img_b64_str + + +def test_list_models(): + model_list = openai.models.list() + names = [x.id for x in model_list.data] + return names + + +def test_chat_completion(model): + image_url = "https://picsum.photos/seed/picsum/1024/1024" + base64_image_url = f"data:image/jpeg;base64,{encode_image(image_url)}" + + # No Image + completion = openai.chat.completions.create( + model=model, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "Tell me about alpacas."}, + ], + } + ], + temperature=0, + ) + print(completion.choices[0].message.content) + print("=" * 25) + + # Image using url link + completion = openai.chat.completions.create( + model=model, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What’s in this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + } + ], + temperature=0, + ) + print(completion.choices[0].message.content) + print("=" * 25) + + # Image using base64 image url + completion = openai.chat.completions.create( + model=model, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What’s in this image?"}, + {"type": "image_url", "image_url": {"url": base64_image_url}}, + ], + } + ], + temperature=0, + ) + print(completion.choices[0].message.content) + print("=" * 25) + + +def test_chat_completion_stream(model): + image_url = "https://picsum.photos/seed/picsum/1024/1024" + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What’s in this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + } + ] + res = openai.chat.completions.create( + model=model, messages=messages, stream=True, temperature=0 + ) + for chunk in res: + try: + content = chunk.choices[0].delta.content + if content is None: + content = "" + except Exception as e: + content = chunk.choices[0].delta.get("content", "") + print(content, end="", flush=True) + print() + + +def test_openai_curl(): + run_cmd( + """curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "llava-v1.5-7b", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What’s in this image?" + }, + { + "type": "image_url", + "image_url": { + "url": "https://picsum.photos/seed/picsum/1024/1024" + } + } + ] + } + ], + "max_tokens": 300 + }' + """ + ) + + print() + + +if __name__ == "__main__": + models = test_list_models() + print(f"models: {models}") + + for model in models: + print(f"===== Test {model} ======") + test_chat_completion(model) + test_chat_completion_stream(model) + test_openai_curl()
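For reference, the vision tests above are driven by the multimodal variant of the test server. A typical end-to-end run, as an assumed invocation from the `tests/` directory that mirrors the docstrings above, would look like:

```bash
# Sketch only: start the multimodal test server, run the vision tests, then clean up.
python3 launch_openai_api_test_server.py --multimodal
python3 test_openai_vision_api.py
bash killall_python.sh  # kills the background fastchat python processes
```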