File size: 2,214 Bytes
47cbe93
 
284ece4
b8ba1e0
859873c
b8ba1e0
 
 
deea6e3
b8ba1e0
 
 
d861a5c
b8ba1e0
deea6e3
b8ba1e0
47cbe93
060d065
 
 
0266d0b
b8ba1e0
 
 
0266d0b
 
b8ba1e0
 
47cbe93
b8ba1e0
 
47cbe93
b8ba1e0
47cbe93
b8ba1e0
 
 
 
 
 
47cbe93
b8ba1e0
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from transformers import pipeline, Conversation
import gradio as gr
import torch
#https://huggingface.co/TheBloke/starchat-beta-GPTQ

from transformers import AutoTokenizer, pipeline, logging
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
import argparse

# Model repo on the Hugging Face Hub.
model_name_or_path = "TheBloke/starchat-beta-GPTQ"
# To load from disk instead, point this at the local download path, e.g.
# model_name_or_path = "/path/to/models/The_Bloke_starchat-beta-GPTQ"

use_triton = False

# Fast (Rust-backed) tokenizer that ships with the quantized checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

# Prefer the GPU when one is available, otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the GPTQ-quantized weights from safetensors onto the chosen device.
model = AutoGPTQForCausalLM.from_quantized(
    model_name_or_path,
    use_safetensors=True,
    device=device,
    use_triton=use_triton,
    quantize_config=None,
)

# Silence the spurious transformers error that appears when an AutoGPTQ
# model is wrapped in a pipeline.
logging.set_verbosity(logging.CRITICAL)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# StarChat dialogue format: each turn is terminated by the special
# <|end|> token, whose ID is 49155 (also passed as eos_token_id below).
prompt_template = "<|system|>\n<|end|>\n<|user|>\n{query}<|end|>\n<|assistant|>"
prompt = prompt_template.format(query="How do I sort a list in Python?")

outputs = pipe(
    prompt,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.2,
    top_k=50,
    top_p=0.95,
    eos_token_id=49155,
)
# generated_text contains the prompt followed by the sampled completion.
print(outputs[0]["generated_text"])

# --- Disabled Gradio chat UI scaffold ---
# NOTE(review): this draft calls an undefined `chatbot(...)` object; before
# un-commenting, wire it to a conversational pipeline (e.g. build one from
# `model`/`tokenizer` above) or it will raise NameError at the first message.
#message_list = []
#response_list = []

#def vanilla_chatbot(message, history):
#    conversation = Conversation(text=message, past_user_inputs=message_list, generated_responses=response_list)
#    conversation = chatbot(conversation)

#    return conversation.generated_responses[-1]

#demo_chatbot = gr.ChatInterface(vanilla_chatbot, title="Vanilla Chatbot", description="Enter text to start chatting.")

#demo_chatbot.launch()