---
license: mit
language:
- en
base_model:
- meta-llama/Llama-3.1-8B-Instruct
pipeline_tag: token-classification
---

<div align="center">
<h1>
MedSSS-8B-PRM
</h1>
</div>

<div align="center">
<a href="https://github.com/pixas/MedSSS" target="_blank">GitHub</a> | <a href="" target="_blank">Paper</a>
</div>

# <span>Introduction</span>
**MedSSS-PRM** is the process reward model (PRM) designed for slow-thinking medical reasoning. It assigns a float value in `[0, 1]` to every internal reasoning step generated by **MedSSS-Policy**.
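For illustration only (the scores below are invented), the PRM maps a step-wise generation from the policy model to one score per step:

```python
# Hypothetical example of the step-wise format produced by MedSSS-Policy and
# the kind of per-step scores MedSSS-PRM returns; the numbers are made up.
generation = (
    "Step 0: Let's break down this problem step by step.\n\n"
    "Step 1: First, identify the likely cause of the cough. [omitted]"
)
step_scores = [0.82, 0.64]  # one float in [0, 1] per "Step k:" segment
```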
For more information, visit our GitHub repository: [https://github.com/pixas/MedSSS](https://github.com/pixas/MedSSS).
# <span>Usage</span>
We build the PRM as a LoRA adapter, which keeps its memory footprint small.

Because the adapter is built on `meta-llama/Llama-3.1-8B-Instruct`, you first need to prepare the base model on your platform.
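For example, one way to pre-fetch the base weights (assuming you have accepted the Llama 3.1 license on the Hugging Face Hub and are logged in via `huggingface-cli login`) is:

```python
# Optional: download the gated base model once so that later from_pretrained
# calls can load it from the local cache. Requires Hub access to Llama 3.1.
from huggingface_hub import snapshot_download

snapshot_download("meta-llama/Llama-3.1-8B-Instruct")
```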
```python
import torch
from itertools import chain

from peft import PeftModel
from transformers import AutoModelForTokenClassification, AutoTokenizer


def obtain_prm_value_for_single_pair(tokenizer, value_model, inputs, outputs):
    # `inputs` is the user question; `outputs` is the step-wise generation
    # produced by MedSSS-Policy, with steps separated by "\n\nStep".
    response = outputs
    completions = [
        completion if completion.startswith("Step") else "Step" + completion
        for completion in outputs.split("\n\nStep")
    ]

    # Rebuild the chat-formatted prompt so the PRM scores the response in the
    # same context the policy model saw.
    messages = [
        {"role": "user", "content": inputs},
        {"role": "assistant", "content": response}
    ]
    input_text = tokenizer.apply_chat_template(messages, tokenize=False)

    # Split the chat template into the text before and after the response.
    response_begin_index = input_text.index(response)
    pre_response_input = input_text[:response_begin_index]
    after_response_input = input_text[response_begin_index + len(response):]

    # Tokenize each step separately so we know where every step ends.
    completion_ids = [
        tokenizer(completion + "\n\n", add_special_tokens=False)['input_ids']
        for completion in completions
    ]
    response_id = list(chain(*completion_ids))
    pre_response_id = tokenizer(pre_response_input, add_special_tokens=False)['input_ids']
    after_response_id = tokenizer(after_response_input, add_special_tokens=False)['input_ids']

    input_ids = pre_response_id + response_id + after_response_id

    # Token-classification forward pass; assuming a single-label value head,
    # the squeezed logits have shape [1, N] with one value per token.
    prm_output = value_model(input_ids=torch.tensor(input_ids).unsqueeze(0).to(value_model.device))
    value = prm_output.logits.squeeze(-1)  # [1, N]

    # The value of each step is read at the last token of that step.
    completion_index = []
    for i, completion in enumerate(completion_ids):
        if i == 0:
            completion_index.append(len(completion) + len(pre_response_id) - 1)
        else:
            completion_index.append(completion_index[-1] + len(completion))

    step_value = value[0, completion_index].float().cpu().numpy().tolist()
    return step_value


# Load the base model, attach the LoRA adapter, and load the tokenizer.
base_model = AutoModelForTokenClassification.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", torch_dtype="auto", device_map="auto"
)
model = PeftModel.from_pretrained(base_model, "pixas/MedSSS_PRM", torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("pixas/MedSSS_PRM")

input_text = "How to stop a cough?"
step_wise_generation = "Step 0: Let's break down this problem step by step.\n\nStep 1: First [omitted]"

value = obtain_prm_value_for_single_pair(tokenizer, model, input_text, step_wise_generation)
print(value)
```
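The returned list contains one float per `Step k:` segment, so for the two-step example above `value` has two entries.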
MedSSS-PRM uses `"\n\nStep"` to separate intermediate steps, so the classification value for each step is taken from the token just before the next `"Step k: "` marker (or the end of the sequence for the final step).
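A minimal sketch of how this step segmentation behaves (illustrative input only):

```python
# Sketch of the step segmentation used in obtain_prm_value_for_single_pair:
# split on "\n\nStep" and re-prefix the later chunks with "Step".
generation = "Step 0: Break the problem down.\n\nStep 1: First [omitted]"
steps = [c if c.startswith("Step") else "Step" + c for c in generation.split("\n\nStep")]
print(steps)  # ['Step 0: Break the problem down.', 'Step 1: First [omitted]']
```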