Text Generation
Transformers
PyTorch
Japanese
llama
text-generation-inference
Inference Endpoints
openorca_stx / README.md
ptrdvn's picture
Update README.md
7b833f2
|
raw
history blame
6.98 kB
---
license: llama2
datasets:
- snow_simplified_japanese_corpus
- khalidalt/tydiqa-goldp
- csebuetnlp/xlsum
language:
- ja
---
# About
This model is Lightblue's QLoRA finetune of OpenOrca's [Open-Orca/OpenOrcaxOpenChat-Preview2-13B](https://huggingface.co/Open-Orca/OpenOrcaxOpenChat-Preview2-13B) model on Japanese fine-tuning datasets.
This model specialises on answering **Closed Question Answering** in Japanese. Input a piece of reference text, ask a question, and see the model answer based on the reference text.
We trained on equal samples of the following three datasets:
* [SNOW](https://huggingface.co/datasets/snow_simplified_japanese_corpus)
* [TyDiQA (Ja)](https://huggingface.co/datasets/khalidalt/tydiqa-goldp)
* [XLSUM (Ja)](https://huggingface.co/datasets/csebuetnlp/xlsum)
which resulted in a dataset of 13,167 samples total.
These three datasets were chosen as they represent three distinct fine-tuning tasks (Text simplification, question answering, and text summarization, respectively) which we hypothesize can help to improve the language models suitability for dealing with Japanese data.
These three datasets make up the model name: STX.
With these datasets, we achieve the following scores on the JGLUE benchmark:
| Model Name | Open-Orca/OpenOrcaxOpenChat-Preview2-13B | lightblue/openorca_stx |
|------------------------|------------------------------------------|------------------------|
| jsquad-1.1-0.3 | 0.692 | 0.836 |
| jcommonsenseqa-1.1-0.3 | 0.831 | 0.782 |
| jnli-1.1-0.3 | 0.504 | 0.48 |
| marc_ja-1.1-0.3 | 0.936 | 0.959 |
Our model achieves much better results on the question answering benchmark (JSQuAD) than the base checkpoint without monstrous degradation of performance on multi-choice question benchmarks (JCommonSense, JNLI, MARC-Ja) purely through QLoRA training.
This shows the potential for applying strong language models such as [Open-Orca/OpenOrcaxOpenChat-Preview2-13B](https://huggingface.co/Open-Orca/OpenOrcaxOpenChat-Preview2-13B) to minimal QLoRA fine-tuning using Japanese fine-tuning datasets to achieve better results at narrow NLP tasks.
# How to use
```python
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
model_dir = "lightblue/openorca_stx"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(
model_dir, torch_dtype=torch.bfloat16, device_map='auto',
)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
def do_closed_qa(context, question):
return context + "\n\n" + question
test_article = """ใ€€ใƒขใƒŽใƒžใƒใฎใƒฌใƒ‘ใƒผใƒˆใƒชใƒผใซใ€Œใƒชใƒผใƒใƒปใƒžใ‚คใ‚ฑใƒซ้ธๆ‰‹ใ€ใŒใ‚ใ‚‹ใƒฌใ‚คใ‚ถใƒผใƒฉใƒขใƒณRGใ•ใ‚“ใ€‚ๆœฌไบบๅ…ฌ่ชใฎใƒขใƒŽใƒžใƒใงใ™ใŒใ€ใƒฉใ‚ฐใƒ“ใƒผใƒ•ใ‚กใƒณใฎๅๅฟœใซๅฐ‘ใ—้ฉšใ„ใŸใใ†ใงใ™ใ€‚
ใ€€ใƒชใƒผใƒใƒปใƒžใ‚คใ‚ฑใƒซ้ธๆ‰‹ใฎใƒขใƒŽใƒžใƒใฏใ€ไฝ•ใŒใใฃใ‹ใ‘ใงใ™ใ‹ใ€‚
ใ€Œ2015ๅนดใฎใƒฏใƒผใƒซใƒ‰ใ‚ซใƒƒใƒ—๏ผˆWๆฏ๏ผ‰ใ‚คใƒณใ‚ฐใƒฉใƒณใƒ‰ๅคงไผšใงๆ—ฅๆœฌใŒๅ—ใ‚ขใƒ•ใƒชใ‚ซใ‚’ๅ€’ใ—ใŸๆฌกใฎๆ—ฅใŒใ€ไบฌ้ƒฝใงใฎ็•ช็ต„ใƒญใ‚ฑใงใ—ใŸใ€‚ๅฝ“ๆ™‚ใฏใ€ใ‚ขใƒƒใƒ—ใƒซใฎๅ…ฑๅŒๅ‰ตๆฅญ่€…ใ‚นใƒ†ใ‚ฃใƒผใƒ–ใƒปใ‚ธใƒงใƒ–ใ‚บใฎใƒขใƒŽใƒžใƒใฐใ‹ใ‚Šใงใ—ใŸใŒใ€ไธ€็ท’ใซใƒญใ‚ฑใ‚’ใ—ใฆใ„ใŸใ‚ธใƒฃใƒณใ‚ฐใƒซใƒใ‚ฑใƒƒใƒˆใ‹ใ‚‰ใ€Žใƒชใƒผใƒใƒปใƒžใ‚คใ‚ฑใƒซใซไผผใฆใพใ™ใ‚ˆใ€‚ใ‚ธใƒงใƒ–ใ‚บใฎใพใพใ€ใ„ใ‘ใ‚‹ใ‚“ใ˜ใ‚ƒใชใ„ใงใ™ใ‹๏ผŸใ€ใจ่จ€ใ‚ใ‚ŒใŸใฎใŒๅง‹ใพใ‚Šใงใ™ใ€
ใ€ŒใŸใ ใ€ใฟใ‚“ใช็Ÿฅ่ญ˜ใŒใชใ„ใ€‚ใƒฉใ‚ฐใƒ“ใƒผใ‚ทใƒงใƒƒใƒ—ใ‚’ๆŽขใ—ใ€ๆ—ฅๆœฌไปฃ่กจใฎใƒฆใƒ‹ใƒ›ใƒผใƒ ใŒๅฃฒใ‚Šๅˆ‡ใ‚Œใ ใฃใŸใฎใงใ€่ตคใฃใฝใ„ใƒฆใƒ‹ใƒ›ใƒผใƒ ใจใƒ”ใƒใƒ”ใƒใฎ็Ÿญใƒ‘ใƒณใ‚’ใฏใ„ใฆใ€‚ใจใ‚Šใ‚ใˆใšSNSใงใ€Žใƒชใƒผใƒใƒปใƒžใ‚คใ‚ฑใƒซใงใ™ใ€ใฃใฆใ„ใฃใฑใ„ๅ†™็œŸใ‚’่ผ‰ใ›ใพใ—ใŸใ€
ใ€Œใ™ใ‚‹ใจใ€ใใ‚Œใ‚’่ฆ‹ใŸใƒชใƒผใƒใ•ใ‚“ๆœฌไบบใ‹ใ‚‰DM๏ผˆใƒ€ใ‚คใƒฌใ‚ฏใƒˆใƒกใƒƒใ‚ปใƒผใ‚ธ๏ผ‰ใŒๅฑŠใใพใ—ใŸใ€‚ใ€ŽใƒขใƒŽใƒžใƒใ‚ใ‚ŠใŒใจใ†ใ”ใ–ใ„ใพใ™ใ€‚ใ‚‚ใ—ใƒขใƒŽใƒžใƒใ‚’ใ™ใ‚‹ใชใ‚‰ใ€ๅƒ•ใฎใƒฆใƒ‹ใƒ›ใƒผใƒ ใ‚’้€ใ‚Šใพใ™ใฎใง็€ใฆใใ ใ•ใ„ใ€ใจใ€‚WๆฏๅพŒใซใƒฆใƒ‹ใƒ›ใƒผใƒ 2็€ใจใƒ‘ใƒณใƒ„ใ‚„ใ‚ฝใƒƒใ‚ฏใ‚นใชใฉใ‚’ใปใ‚“ใพใซ้€ใฃใฆใใฆใใ‚Œใพใ—ใŸใ€‚ไปŠ็€ใฆใ„ใ‚‹ใฎใŒใใ‚Œใงใ™ใ€
ใ“ใ‚Œใพใงใ€ๆ•ฐใ€…ใฎ่‘—ๅไบบใ‚’ใƒขใƒŽใƒžใƒใ—ใฆใ“ใ‚‰ใ‚Œใพใ—ใŸใ€‚ใƒชใƒผใƒ้ธๆ‰‹ใฎใƒใ‚ฟใฎๅ้Ÿฟใฏใ„ใ‹ใŒใงใ—ใŸใ‹ใ€‚
ใ€€ใ€Œๅƒ•ใฏใƒฉใ‚ฐใƒ“ใƒผ็ตŒ้จ“ใŒใชใ„ใงใ™ใ—ใ€ใƒฉใ‚ฐใƒ“ใƒผใ‚’ๅ…จ็„ถ็Ÿฅใ‚‰ใชใ‹ใฃใŸใ‘ใฉใ€ใ‚„ใฃใฑใ‚Šๆœฌไบบใ‹ใ‚‰ใƒฆใƒ‹ใƒ›ใƒผใƒ ใ‚’้ ‚ใ„ใฆใ‚‹ใฃใฆใ„ใ†โ€œๅฐ็ฑ ๏ผˆใ„ใ‚“ใ‚ใ†๏ผ‰โ€ใฟใŸใ„ใชใฎใŒใ‚ใฃใฆใ€‚ใ€Žใ‚ใ„ใคใฏใƒชใƒผใƒใ•ใ‚“ๆœฌไบบใซ่ชใ‚ใ‚‰ใ‚Œใฆใ‚‹ใ€ใจใ€‚ไธ€็›ฎ็ฝฎใ‹ใ‚Œใฆใ„ใ‚‹ใฎใ‹ใชใจๆ„Ÿใ˜ใพใ™ใ€
ใ€€ใ€Œใ‚„ใฃใฆใ„ใ‚‹ใ“ใจใฏใ€่ฆ‹ใŸ็›ฎใ‚’ๆœฌไบบใซๅฏ„ใ›ใฆใƒฏใƒณใƒใƒผใƒ ใฃใฆ่จ€ใ†ใ ใ‘ใชใ‚“ใงใ™ใ‘ใฉใญใ€‚ใใ‚Œใงใ‚‚ใ€Žใ‚ใ‚ใ€ใƒชใƒผใƒใ•ใ‚“ใ ใ€ใจ่จ€ใฃใฆใ‚‚ใ‚‰ใˆใพใ™ใ€
ใ€€ใ€Œใƒชใƒผใƒใ•ใ‚“ใจๅฎŸ้š›ใซไผšใ†ใ“ใจใชใ‚“ใฆใ€็ฐกๅ˜ใซใฏใงใใชใ„ใ˜ใ‚ƒใชใ„ใงใ™ใ‹ใ€‚ใงใ‚‚ใ€ใƒชใƒผใƒใ•ใ‚“ใฎใพใญใ‚’ใ—ใฆใ„ใ‚‹RGใซใฏไผšใˆใŸใ‚ใ€ใฟใŸใ„ใช๏ผˆ็ฌ‘๏ผ‰ใ€‚ไฝ•ใ ใ‚ใ†ใชใ€ๆœ‰ๅใช็ฅž็คพใฎๆ”ฏ็คพใฎใ‚ˆใ†ใชๅญ˜ๅœจใงใ™ใ‹ใญใ€‚ใ‚ใ‚ŠใŒใŸใŒใ‚‰ใ‚Œใ‚‹ใจใ„ใ†ๆ„ๅ‘ณใงใฏไป–ใฎใƒขใƒŽใƒžใƒใจใฏใ™ใ”ใ้•ใ„ใพใ™ใญใ€
"""
test_question = "ใ€€ใƒชใƒผใƒใƒปใƒžใ‚คใ‚ฑใƒซใฏไฝ•ใ‚’้€ใฃใฆใใพใ—ใŸใ‹๏ผŸ"
pipe(do_closed_qa(test_article, question), max_new_tokens=128, temperature=0)[0]["generated_text"]
# "ใƒฆใƒ‹ใƒ›ใƒผใƒ 2็€ใจใƒ‘ใƒณใƒ„ใ‚„ใ‚ฝใƒƒใ‚ฏใ‚นใชใฉ"
```
# Training details
We trained using the following three minimalistic prompt templates for the three tasks in STX:
* SNOW
```python
f"""ๅ…ƒใฎๆ—ฅๆœฌ่ชž๏ผš
{original_ja}
ใ‚ทใƒณใƒ—ใƒซใชๆ—ฅๆœฌ่ชž๏ผš"""
```
* TyDiQA
```python
f"""{passage_text}
{question_text}"""
```
* XLSum
```python
f"""่จ˜ไบ‹๏ผš
{article_text}
่ฆ็ด„๏ผš"""
```
This model was trained for 1000 steps (1.2 epochs) with the model being evaluated every 50 steps. We then chose the best model from these evaluations based on validation loss.
We used the [qlora](https://github.com/artidoro/qlora) package from artidoro.
We trained with the following hyperparameters:
```
Per device evaluation batch size: 16
Per device train batch size: 8
LoRA (lora_r): 64
LoRA alpha (lora_alpha): 16
LoRA modules: all
Double quantization: Enabled
Quantization type: nf4
BF16: Enabled
Bits: 4
Warmup ratio: 0.03
Learning rate scheduler type: Constant
Gradient checkpointing: Enabled
Gradient accumulation steps: 2
Learning rate: 0.0002
Adam beta2: 0.999
Maximum gradient norm: 0.3
LoRA dropout: 0.05
Weight decay: 0.0
```
![image/png](https://cdn-uploads.huggingface.co/production/uploads/64b63f8ad57e02621dc93c8b/UWiE7z5tG8t_vdSFrb5WC.png)
![image/png](https://cdn-uploads.huggingface.co/production/uploads/64b63f8ad57e02621dc93c8b/_fKBf9sdq9UAKKYMxM6ad.png)