---
license: mit
datasets:
- OpenAssistant/oasst1
language:
- en
pipeline_tag: conversational
---

# Ava small

## Training Details

The fine-tuning process for this model involved several key parameters and settings:

- **Base Model:** GPT-2
- **Dataset:** Open Assistant's [oasst1](https://huggingface.co/datasets/OpenAssistant/oasst1) dataset
- **Learning Rate:** 1e-3
- **Epochs:** 10
- **Hardware:** GPU P100

The model was trained on a P100 GPU to speed up training and take advantage of the hardware's parallel processing capabilities. The learning rate was set to 1e-3 to balance fast convergence against the risk of overshooting.
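
The training code itself is not published with this card, so the sketch below is purely illustrative: it shows how a fine-tune with the hyperparameters above could be launched with the `transformers` `Trainer`. The oasst1 preprocessing (joining each prompt to its reply and rendering the pair in the custom-token format described in the next section) is collapsed to a single hand-written example, and the batch size and sequence length are assumptions, as neither is documented here.

```python
# Hypothetical sketch: the card documents the hyperparameters, not the
# training code. Batch size and max length are assumptions.
from datasets import Dataset
from transformers import (DataCollatorForLanguageModeling, GPT2LMHeadModel,
                          GPT2Tokenizer, Trainer, TrainingArguments)

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
# Register the conversation markers (their meaning is described below)
tokenizer.add_tokens(['<startoftext>', '<endoftext>', '<ava>', '</ava>', '<user>', '</user>'])

model = GPT2LMHeadModel.from_pretrained('gpt2')
model.resize_token_embeddings(len(tokenizer))  # make room for the new tokens

# Stand-in for the real preprocessing: oasst1 stores one message per row of a
# conversation tree, so prompts must be joined to their replies before formatting.
train_texts = [
    '<startoftext><user>Hello</user><ava>Hello there, how can I assist you today?</ava><endoftext>',
]
dataset = Dataset.from_dict({'text': train_texts}).map(
    lambda batch: tokenizer(batch['text'], truncation=True, max_length=512),
    batched=True,
    remove_columns=['text'],
)

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir='ava-small',
        learning_rate=1e-3,             # as documented above
        num_train_epochs=10,            # as documented above
        per_device_train_batch_size=8,  # assumption: not documented
    ),
    train_dataset=dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
)
trainer.train()
```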

## Model Performance

After 10 epochs of training, the model achieved improved performance at generating coherent and contextually relevant conversational responses. Note, however, that its responses may still exhibit occasional inaccuracies or inconsistencies.

## Custom Tokens and Contextualization

To facilitate structured conversations and improve response generation, the following custom tokens were added:

- `<startoftext>`: Marks the beginning of a conversation prompt.
- `<endoftext>`: Marks the end of a conversation prompt.
- `<ava>`: Denotes the beginning of responses generated by the AI assistant.
- `</ava>`: Denotes the end of AI-generated responses.
- `<user>`: Denotes the beginning of user input in the conversation.
- `</user>`: Denotes the end of user input.

Here is an example of the prompt format:

```
<startoftext><user>Hello</user><ava>Hello there, how can I assist you today?</ava><endoftext>
```
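
These markers are not part of the base GPT-2 vocabulary, so they must be registered with the tokenizer (and the embedding matrix resized) before they behave as single units. A minimal sketch, assuming they were added via `add_tokens` as regular rather than special tokens, which is consistent with the inference script below splitting on them after decoding with `skip_special_tokens=True`:

```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

# add_tokens() registers the markers as regular tokens, so decoding with
# skip_special_tokens=True leaves them in place for later splitting.
markers = ['<startoftext>', '<endoftext>', '<ava>', '</ava>', '<user>', '</user>']
tokenizer.add_tokens(markers)
model.resize_token_embeddings(len(tokenizer))  # extend the embedding matrix

# Each marker now maps to a single token:
print(tokenizer.tokenize('<user>Hello</user>'))  # ['<user>', 'Hello', '</user>']
```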

## Use Cases and Applications

Given its training on dialogues and conversations, this fine-tuned model is particularly well-suited for the following use cases:

- Dynamic and engaging conversations with users in chatbots or virtual assistants.
- Providing personalized information and assistance across diverse domains.
- Generating contextually relevant and creative responses to user inputs.
- Enhancing the user experience and interaction quality.

## Inference Script

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

def inference(text, model, tokenizer):
    # Wrap the user input in the conversation markers and leave the prompt
    # open at <ava> so the model completes the assistant's turn.
    data = tokenizer.encode(f'<startoftext><user>{text}</user><ava>', return_tensors='pt')
    input_ids = data.to(device)

    output = model.generate(
        input_ids=input_ids,
        do_sample=True,  # without this, temperature/top_k/top_p are ignored
        temperature=0.8,
        max_length=100,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.2,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,  # silences the missing-pad-token warning
    )

    # The markers survive decoding (assuming they were added as regular,
    # non-special tokens), so the assistant's turn can be sliced out.
    decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
    ava_response = decoded_output.split('<ava>')[1].split('</ava>')[0]
    clean_response = ava_response.split('.')[0].strip()  # keep only the first sentence

    return clean_response

model_name = 'Kuduxaaa/ava-small'
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

user_input = "What's the weather like today?"
response = inference(user_input, model, tokenizer)

print('Ava:', response)
```
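
A few notes on the generation settings: `temperature`, `top_k`, and `top_p` only take effect because `do_sample=True` is set; with sampling disabled, `generate` falls back to greedy decoding and those parameters have no effect. The final `split('.')` keeps only the first sentence of the reply, trading completeness for tidiness; drop that line if full responses are preferred.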