Mistral-7B fine-tuned on AgentInstruct
Mistral-7B-v0.1 fine-tuned on the AgentInstruct dataset to improve its ability to act as an agent.
Model Details
Model Description
The Mistral-7B-v0.1 Large Language Model (LLM) is a pretrained generative text model with 7 billion parameters. According to Mistral AI, Mistral-7B-v0.1 outperforms Llama 2 13B on all benchmarks they tested.
For full details of the base model, see the Mistral AI paper and release blog post.
Model Architecture
Mistral-7B-v0.1 is a transformer model, with the following architecture choices:
- Grouped-Query Attention
- Sliding-Window Attention
- Byte-fallback BPE tokenizer
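The first two architecture choices can be illustrated with a small, framework-free sketch. It is not Mistral's implementation: it builds a boolean sliding-window causal mask and maps each query head to the key/value head it shares under grouped-query attention. The window is assumed to count the current token (implementations can differ by one on this boundary), and the head counts in the usage lines (32 query heads, 8 key/value heads) are taken from the published Mistral-7B-v0.1 configuration.

```python
def sliding_window_causal_mask(seq_len, window):
    """Return mask[i][j] == True when query position i may attend to key position j.

    Attention is causal (j <= i) and limited to the `window` most recent
    positions, counting position i itself (an assumed convention).
    """
    return [[j <= i and i - j < window for j in range(seq_len)] for i in range(seq_len)]


def kv_head_for_query_head(q_head, n_q_heads, n_kv_heads):
    """Grouped-query attention: each group of consecutive query heads shares one KV head."""
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be a multiple of n_kv_heads")
    group_size = n_q_heads // n_kv_heads
    return q_head // group_size


# Toy example: 5 positions, window of 3 ("x" = may attend, "." = masked)
# prints: x....  xx...  xxx..  .xxx.  ..xxx  (one row per line)
mask = sliding_window_causal_mask(seq_len=5, window=3)
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))

# With 32 query heads and 8 KV heads, query heads 0-3 share KV head 0, 4-7 share KV head 1, ...
print([kv_head_for_query_head(h, 32, 8) for h in (0, 3, 4, 31)])  # [0, 0, 1, 7]
```

Each layer only looks back a fixed number of positions, but stacking layers lets information travel further than one window; sharing key/value heads shrinks the KV cache by the group size (4x for 32/8).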
Dataset Details
AgentInstruct is a curated dataset of 1,866 high-quality interaction trajectories covering six real-world agent tasks, built with methods such as Task Derivation and Self-Instruct.
- CoT - Trajectories follow the ReAct format, with a thought explaining each action, which makes the model's decision-making process explicit.
- Diversity - Spans 6 real-world scenarios, from daily household routines to database operations, with average turns ranging from 5 to 35.
- Precision - Not every GPT-4 trajectory is effective, so trajectories are filtered with strict reward criteria to keep only high-quality ones.
- Assurance - Rigorous checks are applied to avoid data leakage.
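The Precision point describes reward-based filtering of GPT-4 trajectories. A minimal sketch of that idea follows. The `Trajectory` record, the 0-to-1 reward scale and the `min_reward=1.0` threshold are all hypothetical, since the exact filtering criteria are not given here, and the example records are made up rather than taken from AgentInstruct.

```python
from dataclasses import dataclass, field


@dataclass
class Trajectory:
    task: str
    reward: float  # assumed task score in [0, 1]; 1.0 = task fully solved
    turns: list = field(default_factory=list)  # ReAct-style "Thought: ... Action: ..." messages


def filter_trajectories(trajectories, min_reward=1.0):
    """Keep only trajectories whose reward reaches a (hypothetical) strict threshold."""
    return [t for t in trajectories if t.reward >= min_reward]


# Illustrative, made-up trajectories (not taken from AgentInstruct)
pool = [
    Trajectory("ALFWorld", 1.0, ["Thought: I need to find a mug first.\nAction: go to countertop 1"]),
    Trajectory("WebShop", 0.6, ["Thought: I should search for the product.\nAction: search[red shoes]"]),
    Trajectory("Database", 1.0, ["Thought: I need to count the rows.\nAction: Operation\nSELECT COUNT(*) FROM t"]),
]
kept = filter_trajectories(pool)
print([t.task for t in kept])  # ['ALFWorld', 'Database']
```

A strict threshold trades dataset size for quality: partially successful trajectories (such as the WebShop one above) are discarded rather than used as training targets.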
Task Overview
Task | # Filt. Traj. | Avg. # Turns per Filt. Traj. |
---|---|---|
ALFWorld | 336 | 13.52 |
WebShop | 351 | 3.68 |
Mind2Web | 122 | 1.00 |
Knowledge Graph | 324 | 6.04 |
Operating System | 195 | 3.85 |
Database | 538 | 2.06 |
AgentInstruct (total) | 1866 | 5.24 |
AgentInstruct includes 1,866 trajectories from 6 agent tasks. "Traj." stands for interaction trajectory, and "Filt. Traj." for filtered trajectories.
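As a quick consistency check on the table, the total row can be reproduced as a trajectory-weighted mean of the per-task averages (assuming that is how it was computed): each task's average turn count is weighted by its number of filtered trajectories.

```python
# (task, filtered trajectories, average turns per filtered trajectory), copied from the table
tasks = [
    ("ALFWorld", 336, 13.52),
    ("WebShop", 351, 3.68),
    ("Mind2Web", 122, 1.00),
    ("Knowledge Graph", 324, 6.04),
    ("Operating System", 195, 3.85),
    ("Database", 538, 2.06),
]

total = sum(n for _, n, _ in tasks)
# Weighted mean: sum(n_i * avg_i) / sum(n_i)
overall_avg = sum(n * avg for _, n, avg in tasks) / total
print(total, round(overall_avg, 2))  # 1866 5.24
```

The weighted sum is about 9,772 turns over 1,866 trajectories, which matches the 5.24 in the total row; an unweighted mean of the six task averages would give a different figure (about 5.03).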
Training Details
TBD
Example of usage
```python
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
import torch

model_id = "mrm8488/mistral-7b-ft-AgentInstruct"

# Load tokenizer and model (float16 halves GPU memory compared with the default float32)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")


class MyStoppingCriteria(StoppingCriteria):
    """Stop generating once `target_sequence` appears in the newly generated text."""

    def __init__(self, tokenizer, target_sequence, prompt_length):
        self.tokenizer = tokenizer
        self.target_sequence = target_sequence
        self.prompt_length = prompt_length  # number of prompt tokens to skip

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only the tokens generated after the prompt and check for the target sequence
        generated_text = self.tokenizer.decode(input_ids[0, self.prompt_length:])
        return self.target_sequence in generated_text


def generate(context, max_new_tokens=256, min_new_tokens=64, temperature=0.3, top_p=0.75, top_k=40, do_sample=True, num_beams=2):
    # Tokenize the prompt and move input_ids and attention_mask to the GPU
    inputs = tokenizer(context, return_tensors="pt").to("cuda")
    prompt_length = inputs["input_ids"].shape[1]
    # Generation settings
    generation_settings = {
        "max_new_tokens": max_new_tokens,
        "min_new_tokens": min_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "do_sample": do_sample,
        "num_beams": num_beams,
        "early_stopping": False,
        "use_cache": True,
        # generate() expects a StoppingCriteriaList, not a bare StoppingCriteria
        "stopping_criteria": StoppingCriteriaList([MyStoppingCriteria(tokenizer, "### human:", prompt_length)]),
    }
    # Generate a response; input_ids and attention_mask are passed as keyword arguments
    with torch.no_grad():
        generation_output = model.generate(**inputs, **generation_settings)
    # generate() returns token ids (prompt + continuation); decode only the continuation
    return tokenizer.decode(generation_output[0, prompt_length:], skip_special_tokens=True)


# Example usage: a database task
context = """### human: Among the reference ID of under 10 who got response by marketing department, compare their education status.
There are 2 tables involved with this task. The name of the 1st table is Customers, and the headers of this table are ID,SEX,MARITAL_STATUS,GEOID,EDUCATIONNUM,OCCUPATION,age. The name of the 2nd table is Mailings1_2, and the headers of this table are REFID,REF_DATE,RESPONSE."""
solution = generate(context)
print(solution)
```
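The usage example writes the database-task prompt by hand and stops on the `### human:` marker. The two hypothetical helpers below (`build_db_prompt` and `truncate_at_stop`, not part of the model or of `transformers`) reproduce that prompt layout, modeled only on the single two-table example, and cut a generated response at the stop marker, because the stopping criterion fires only after the marker has already been generated. Apply `truncate_at_stop` to the newly generated text only, not to a decode that still contains the prompt.

```python
def ordinal(n):
    """1 -> '1st', 2 -> '2nd', 3 -> '3rd', 4 -> '4th', 11 -> '11th', 21 -> '21st'."""
    if 10 <= n % 100 <= 20:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return f"{n}{suffix}"


def build_db_prompt(question, tables):
    """Hypothetical helper: format a database task like the usage example.

    `tables` is a list of (table_name, [column, ...]) pairs.
    """
    sentences = [f"There are {len(tables)} tables involved with this task."]
    for i, (name, columns) in enumerate(tables, start=1):
        sentences.append(
            f"The name of the {ordinal(i)} table is {name}, "
            f"and the headers of this table are {','.join(columns)}."
        )
    return f"### human: {question}\n" + " ".join(sentences)


def truncate_at_stop(response, stop="### human:"):
    """Hypothetical helper: cut a generated response at the first stop marker, if present."""
    index = response.find(stop)
    return response if index == -1 else response[:index].rstrip()


prompt = build_db_prompt(
    "Among the reference ID of under 10 who got response by marketing department, compare their education status.",
    [
        ("Customers", ["ID", "SEX", "MARITAL_STATUS", "GEOID", "EDUCATIONNUM", "OCCUPATION", "age"]),
        ("Mailings1_2", ["REFID", "REF_DATE", "RESPONSE"]),
    ],
)
print(prompt)
print(truncate_at_stop("SELECT 1;\n### human: next question"))  # SELECT 1;
```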
Citation
```bibtex
@misc{manuel_romero_2024,
  author    = {{Manuel Romero}},
  title     = {mistral-7b-ft-AgentInstruct (Revision 463b96d)},
  year      = 2024,
  url       = {https://huggingface.co/mrm8488/mistral-7b-ft-AgentInstruct},
  doi       = {10.57967/hf/1650},
  publisher = {Hugging Face}
}
```