---
license: apache-2.0
metrics:
  - accuracy
pipeline_tag: text-classification
tags:
  - LSGAttention
language:
  - en
  - hi
  - fr
  - it
  - de
  - es
widget:
  - text: write a python function that counts from 1 to 10?
  - text: If tan A = 3/4, prove that Sin A Cos A = 12/25. solve step by step.
inference: true
---

notdiamond-4k-0001

notdiamond-4k-0001 supports an input sequence length of 4096 tokens. It is an extension of notdiamond-0001, which originally supported a sequence length of 512.
LSG attention is used to adapt the existing pre-trained model so that it efficiently extrapolates to a sequence length of 4096 with no additional training.
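To make the idea concrete, here is a simplified sketch of the kind of local + sparse + global attention pattern that LSG attention combines. It is an illustration under stated assumptions, not the LSG implementation: the function name `lsg_style_mask`, its parameters, and the plain strided sparse selection are hypothetical choices, and a real implementation computes attention block by block instead of materializing a full n × n mask.

```python
from typing import List


def lsg_style_mask(
    seq_len: int, block_size: int, sparse_stride: int, num_global: int
) -> List[List[bool]]:
    """Return mask[q][k] == True when query position q may attend to key position k.

    Hypothetical, simplified local + sparse + global pattern for illustration only.
    Positions 0 .. num_global-1 act as global tokens.
    """
    mask = [[False] * seq_len for _ in range(seq_len)]
    for q in range(seq_len):
        for k in range(seq_len):
            # Global: global tokens see everything and are seen by everything.
            is_global = q < num_global or k < num_global
            # Local: keys in the query's own block or an adjacent block.
            is_local = abs(q // block_size - k // block_size) <= 1
            # Sparse: a strided subset of distant keys (a stand-in for LSG's sparse selection).
            is_sparse = k % sparse_stride == 0
            mask[q][k] = is_global or is_local or is_sparse
    return mask


if __name__ == "__main__":
    mask = lsg_style_mask(seq_len=16, block_size=4, sparse_stride=4, num_global=1)
    for q in (0, 1, 10):
        # Number of keys each query attends to, out of 16 positions.
        print(q, sum(mask[q]))
    # prints:
    # 0 16
    # 1 10
    # 10 13
```

The row sums show the effect: only the global token at position 0 attends to all 16 positions, while ordinary tokens attend to a subset whose size is governed by the block size, the sparse stride and the number of global tokens.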

notdiamond-0001 automatically determines whether to send a query to GPT-3.5 or GPT-4, depending on which model is best suited for your task. notdiamond-0001 was trained on hundreds of thousands of data points from robust, cross-domain evaluation benchmarks. The notdiamond-0001 router model is a classifier and returns a label for either GPT-3.5 or GPT-4.
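Since the router is a two-class classifier, the routing decision is an argmax over two logits. The following minimal sketch turns a pair of logits into a label plus a softmax confidence, assuming the label order `{0: 'gpt-3.5', 1: 'gpt-4'}` used in the inference example below. The helper names `softmax` and `route_from_logits` are hypothetical and not part of any notdiamond API.

```python
import math
from typing import List, Sequence, Tuple

# Label order taken from the inference example in this card.
ID2LABEL = {0: "gpt-3.5", 1: "gpt-4"}


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a 1-D sequence of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def route_from_logits(logits: Sequence[float]) -> Tuple[str, float]:
    """Return the label with the highest logit and its softmax probability (hypothetical helper)."""
    if len(logits) != len(ID2LABEL):
        raise ValueError(f"expected {len(ID2LABEL)} logits, got {len(logits)}")
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return ID2LABEL[best], probs[best]


if __name__ == "__main__":
    label, p = route_from_logits([0.2, 1.5])
    print(label, round(p, 3))  # prints: gpt-4 0.786
```

Because softmax is monotonic, the chosen label is always the same as taking the argmax of the raw logits; the probability is only extra information about how confident the router is.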

Inference:

    import torch
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    # Input format: wrap the user query in the routing instruction
    query = "Can you write a function that counts from 1 to 10?"
    formatted_prompt = f"""Determine whether the following query should be sent to GPT-3.5 or GPT-4.
        Query:
        {query}"""

    tokenizer = AutoTokenizer.from_pretrained("notdiamond/notdiamond-0001")
    model = AutoModelForSequenceClassification.from_pretrained("notdiamond/notdiamond-0001")
    model.eval()

    # max_length=4096 matches notdiamond-4k-0001; the original notdiamond-0001 supports 512 tokens
    inputs = tokenizer(formatted_prompt, truncation=True, max_length=4096, return_tensors="pt")

    # Forward pass without gradient tracking
    with torch.no_grad():
        logits = model(**inputs).logits

    # The predicted class index selects the model that should handle the query
    model_id = logits.argmax(dim=-1).item()
    id2label = {0: 'gpt-3.5', 1: 'gpt-4'}
    model_to_call = id2label[model_id]

You can also access the free notdiamond API; see the official website and documentation for details.