|
---
language: en
license: apache-2.0
datasets:
- derek-thomas/ScienceQA
- allenai/ai2_arc
tags:
- education
- stem
- computer science
- data science
- engineering
- biology
- chemistry
---
|
|
|
|
|
# STEMerald-2b |
|
|
|
**Model name:** STEMerald-2b |
|
|
|
**Model description:** |
|
STEMerald-2b is a fine-tuned version of Gemma-2b designed to answer university-level STEM multiple-choice questions. It was trained with Supervised Fine-Tuning (SFT) followed by Direct Preference Optimization (DPO) to improve the accuracy and reliability of its answers and explanations.
|
|
|
<p align="center"> |
|
<img src="STEMerald_pic.jpeg" alt="STEMerald picture" width="400"/> |
|
</p> |
|
|
|
## Model Details |
|
|
|
**Base Model:** [Gemma-2b](https://arxiv.org/abs/2403.08295) |
|
|
|
**Architecture:** Decoder-only Language Model (Causal) |
|
|
|
**Parameters:** 2.51 billion |
|
|
|
**Quantized Version:** STEMerald-2b-4bit (quantized with 4-bit NormalFloat, NF4)
|
|
|
**Training Framework:** PyTorch with Hugging Face Transformers |
|
|
|
## Datasets |
|
|
|
The model was fine-tuned on a variety of datasets tailored for STEM education, including: |
|
|
|
- **EPFL Preference Pairs Dataset:** 1,522 university-level STEM questions with 26k preference pairs, annotated by students using ChatGPT-3.5 with Chain-of-Thought (CoT) prompting (an illustrative record format is sketched after this list).
|
- **Stack Exchange Dataset:** Questions and answers on topics such as math, computer science, and engineering.
|
- **Orca-Math:** 200k grade-school math word problems to enhance reasoning capabilities. |
|
- **EPFL MCQA Dataset:** Multiple-choice questions with explanations (for CoT), extracted from the winning answers of the EPFL preference pairs.
|
- **ScienceQA:** Multiple-choice questions on biology, physics, chemistry, economics, earth science, and engineering practices. |
|
- **AI2 Reasoning Challenge (ARC):** Grade-school level multiple-choice science questions. |
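
For reference, each preference pair ties one prompt to a preferred ("chosen") and a dispreferred ("rejected") completion. A hypothetical record, with invented values purely for illustration (not taken from the actual dataset), might look like:

```python
# Hypothetical preference-pair record (invented values, for illustration only).
pair = {
    "prompt": (
        "Question: Which data structure offers O(1) average-case lookup?\n"
        "Options: A. Linked list B. Hash table C. Binary heap D. Stack\n"
        "Answer:"
    ),
    "chosen": " B. A hash table maps keys to buckets, giving O(1) average lookup.",
    "rejected": " A. A linked list keeps elements in insertion order.",
}
```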
|
|
|
## Training Process |
|
|
|
The training process for STEMerald-2b involved multiple steps: |
|
|
|
1. **Supervised Fine-Tuning (SFT):** Initial training on datasets like Orca-Math to improve reasoning abilities. |
|
2. **Direct Preference Optimization (DPO):** Training on preference pairs from the EPFL and Stack Exchange datasets to align model outputs with preferred answers (see the sketch after this list).
|
3. **MCQA Fine-Tuning:** Specialization for multiple-choice question answering using datasets like ScienceQA and ARC. |
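
To make the DPO stage concrete, the sketch below shows how such a preference-optimization step can be set up with the `trl` library. This is a hedged illustration, not the actual STEMerald training script: the starting checkpoint, hyperparameters, and data row are assumptions, and `trl` argument names vary between releases.

```python
# Hedged sketch of a DPO stage with trl -- NOT the actual STEMerald training
# script. Checkpoint, data row, and hyperparameters are placeholders.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")  # assumed starting point
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")

# DPO consumes (prompt, chosen, rejected) triples, like the EPFL preference pairs.
train_dataset = Dataset.from_dict({
    "prompt": ["Question: What is the derivative of x^2?\nAnswer:"],
    "chosen": [" 2x, by the power rule."],
    "rejected": [" x^2, it stays the same."],
})

trainer = DPOTrainer(
    model=model,  # with no ref_model given, trl derives the reference model from a copy
    args=DPOConfig(output_dir="stemerald-dpo", beta=0.1),  # beta scales the KL penalty
    train_dataset=train_dataset,
    tokenizer=tokenizer,  # renamed to processing_class= in newer trl releases
)
trainer.train()
```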
|
|
|
## Performance |
|
|
|
The performance of STEMerald-2b was evaluated using various metrics: |
|
|
|
- **Accuracy:** Measured on three MCQA test sets (EPFL MCQA, ScienceQA, and ARC); the full STEMerald-2b recipe reaches 0.750 micro-averaged accuracy, versus 0.518 for the Gemma-it baseline (see Results below).
|
- **Qualitative Evaluation:** The model's answers were evaluated for logical consistency, truthfulness, clarity, and coherence with the final answer. |
|
|
|
### Results |
|
|
|
| Model Version                     | Accuracy (Non-Quantized) | Accuracy (Quantized) |
|-----------------------------------|--------------------------|----------------------|
| it-ORCA-DPO-MCQA _(STEMerald-2b)_ | 0.750                    | 0.720                |
| it-DPO-MCQA                       | 0.744                    | 0.720                |
| it-MCQA                           | 0.736                    | 0.700                |
| it-ORCA-MCQA                      | 0.722                    | 0.714                |
| MCQA                              | 0.702                    | 0.654                |
| DPO-MCQA                          | 0.694                    | 0.674                |
| Gemma-it-OneShot                  | 0.546                    | 0.520                |
| Gemma-it                          | 0.518                    | 0.518                |
|
|
|
Micro-averaged accuracy over three MCQA test sets (EPFL MCQA, ScienceQA, and ARC).
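
To make the metric explicit: micro-averaging pools the correct predictions of all three test sets before dividing by the pooled total, so larger sets contribute proportionally more. A minimal sketch with placeholder counts (not the actual per-set numbers):

```python
# Micro-averaged accuracy: pool correct answers over all test sets,
# then divide by the pooled total. Counts below are placeholders only.
def micro_accuracy(per_set_counts):
    correct = sum(c for c, _ in per_set_counts)
    total = sum(n for _, n in per_set_counts)
    return correct / total

# (correct, total) per test set -- hypothetical values for illustration
print(micro_accuracy([(150, 200), (180, 250), (120, 150)]))  # 0.75
```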
|
|
|
## Use Cases |
|
|
|
STEMerald-2b can be utilized as a STEM course assistant, providing support in areas such as: |
|
|
|
- Answering university-level multiple-choice STEM questions. |
|
- Offering detailed explanations and reasoning for answers. |
|
- Enhancing student engagement and learning efficiency during independent study.
|
|
|
## Ethical Considerations |
|
|
|
While STEMerald-2b aims to provide accurate and helpful responses, it is important to consider potential ethical implications: |
|
|
|
- **Over-Reliance:** Students might become overly dependent on the model for answers, potentially affecting their independent learning and problem-solving skills. |
|
- **Accuracy:** Although efforts were made to ensure truthful responses, the model can still produce incorrect answers; teacher supervision remains crucial.
|
|
|
## Limitations |
|
|
|
- The model's performance may vary based on the specific context and nature of the questions. |
|
- Quantization reduces memory footprint but may slightly affect accuracy. |
|
|
|
## Conclusion |
|
|
|
STEMerald-2b offers a promising solution for enhancing STEM education through advanced language model capabilities. By leveraging fine-tuning techniques and comprehensive datasets, it aims to provide accurate and accessible learning support for students. |
|
|
|
## How to Use |
|
|
|
You can use the model directly with the `transformers` library: |
|
|
|
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("matsant01/STEMerald-2b")
model = AutoModelForCausalLM.from_pretrained("matsant01/STEMerald-2b")

# Format the question in the MCQA style the model was fine-tuned on.
input_text = "Question: What is the derivative of x^2? \nOptions: A. 4x B. 2*x^2 C. 2x D. 2\nAnswer:"
inputs = tokenizer(input_text, return_tensors="pt")

# Reserve enough new tokens for the reasoning and the final answer.
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
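
The model should continue the prompt with its selected option (here, "C. 2x"); since it was fine-tuned with Chain-of-Thought data, prompting it for reasoning can also elicit a short explanation before the final answer.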
|
|
|
For the quantized version, use: |
|
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# 4-bit NormalFloat (NF4) quantization with bfloat16 compute precision.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4"
)

tokenizer = AutoTokenizer.from_pretrained("matsant01/STEMerald-2b-4bit")
model = AutoModelForCausalLM.from_pretrained("matsant01/STEMerald-2b-4bit", quantization_config=quantization_config)
```
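
As a rough guide, NF4 stores weights in about half a byte per parameter, so the ~2.5B-parameter model needs on the order of 1.5 GB of GPU memory once quantization constants and overhead are included, versus roughly 5 GB in 16-bit precision (exact figures depend on hardware and runtime overhead).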
|
|
|
|
|
## Acknowledgements |
|
|
|
We acknowledge the contributions of the EPFL and Stack Exchange communities for their invaluable datasets, and the Hugging Face team for their support and tools that made this project possible. |
|
|
|
## Contact |
|
|
|
For any questions or feedback, please contact: |
|
- [Antonio Mari](https://github.com/antoniomari) (antonio.mari@epfl.ch) |
|
- [Matteo Santelmo](https://github.com/matsant01) (matteo.santelmo@epfl.ch) |
|
- [Stefano Viel](https://github.com/stefanoviel) (stefano.viel@epfl.ch) |