---
license: apache-2.0
metrics:
- accuracy
pipeline_tag: text-classification
tags:
- LSGAttention
language:
- fr
- it
- de
- es
- en
inference: true
---
# notdiamond-4k-0001
notdiamond-4k-0001 supports a **4096-token input sequence length**. This model is an extension of [notdiamond-0001](https://huggingface.co/notdiamond/notdiamond-0001), which originally supported a sequence length of 512.
**LSG attention** is used to adapt the existing pre-trained model so that it efficiently extrapolates to a 4096-token sequence length with no additional training.
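A minimal loading sketch for the extended-context checkpoint (this assumes it is published as `notdiamond/notdiamond-4k-0001` and, like typical LSG conversions, ships custom modeling code that requires `trust_remote_code=True`):
``` python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Assumed repository id for the 4k variant; trust_remote_code=True is typically
# needed for LSG-converted checkpoints that bundle custom attention code.
tokenizer = AutoTokenizer.from_pretrained("notdiamond/notdiamond-4k-0001")
model = AutoModelForSequenceClassification.from_pretrained(
    "notdiamond/notdiamond-4k-0001", trust_remote_code=True
)
```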
notdiamond-0001 automatically determines whether to send queries to GPT-3.5 or GPT-4, depending on which model is best-suited for your task. notdiamond-0001 was trained on hundreds of thousands of data points from robust, cross-domain evaluation benchmarks.
The notdiamond-0001 router model is a classifier and will return a label for either GPT-3.5 or GPT-4.
Inference:
``` python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Format the query into the prompt the router classifier expects
query = "Can you write a function that counts from 1 to 10?"
formatted_prompt = f"""Determine whether the following query should be sent to GPT-3.5 or GPT-4.
Query:
{query}"""

# Load the router tokenizer and classification model
tokenizer = AutoTokenizer.from_pretrained("notdiamond/notdiamond-0001")
model = AutoModelForSequenceClassification.from_pretrained("notdiamond/notdiamond-0001")

# Tokenize with truncation at the 4096-token context limit
inputs = tokenizer(
    formatted_prompt, truncation=True, max_length=4096, return_tensors="pt"
)

# The classifier's argmax picks the target model
logits = model(**inputs).logits
model_id = logits.argmax().item()
id2label = {0: 'gpt-3.5', 1: 'gpt-4'}
model_to_call = id2label[model_id]
```
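Once the router returns a label, the query can be forwarded to the corresponding model. A minimal sketch using the OpenAI Python client (the client usage and the mapping of the `gpt-3.5` label to the `gpt-3.5-turbo` endpoint name are assumptions, not part of this repository):
``` python
from openai import OpenAI

# Assumes OPENAI_API_KEY is set in the environment
client = OpenAI()

# Map the router label to an API model name (endpoint names are assumed)
endpoint = "gpt-4" if model_to_call == "gpt-4" else "gpt-3.5-turbo"

response = client.chat.completions.create(
    model=endpoint,
    messages=[{"role": "user", "content": query}],
)
print(response.choices[0].message.content)
```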
You can also access the free [API](https://www.notdiamond.ai/notdiamond-0001) and the official [documentation](https://notdiamond.readme.io/docs/introduction).