Model Card for shijunju/gemma_7b_finRisk_r6_4VersionQ
This repository contains only a LoRA adapter trained with KerasNLP (keras-nlp) on a TPU. To use the fine-tuned model, load the base Gemma-7b model, enable LoRA, and apply the adapter with `load_lora_weights` before generating output (see instructions below).
This model was fine-tuned from Gemma-7b-en with LoRA (Low-Rank Adaptation) for question answering in the domain of financial risk compliance, using documents from fincen.gov.
It answers questions about documents published on fincen.gov since 2020, including Alerts, Advisories, and Financial Trend Analysis reports.
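LoRA freezes the pretrained Gemma weights and learns only a low-rank update for selected weight matrices. That update is what the adapter in this repository stores. The standard formulation is shown below; the exact scaling used by the Keras implementation may differ. A frozen matrix $W_0 \in \mathbb{R}^{d \times k}$ becomes

$$
W = W_0 + \frac{\alpha}{r} B A, \qquad B \in \mathbb{R}^{d \times r},\; A \in \mathbb{R}^{r \times k},\; r = 6,
$$

so each adapted matrix adds $r(d + k) = 6(d + k)$ trainable parameters instead of $dk$.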
Model Details
Model Description
- Developed by: Shijun Ju
- Fine-tuned from model: Gemma-7b-en
- Fine-tuned with KerasNLP (keras_nlp) on TPU, adapting the approach from https://www.kaggle.com/code/nilaychauhan/keras-gemma-distributed-finetuning-and-inference
- LoRA rank: 6
- Accuracy on unseen paraphrased questions (version 3): 70% with the default generation settings and 78% with beam search (3 beams).
Dataset Used
shijunju/fincen_all_questions_5versions
- The model was trained on four paraphrased versions of the questions ("question_version" 0, 1, 2, and 4) and evaluated on version 3.
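The version-based split can be sketched in plain Python. This sketch assumes each record is a dict with an integer "question_version" key. Any other fields, and the toy records in the usage example, are hypothetical and are not rows from the actual dataset.

```python
# Sketch of the train/test split described above, on in-memory records.
# Assumption: every record is a dict with an integer "question_version" key.
TRAIN_VERSIONS = {0, 1, 2, 4}
TEST_VERSIONS = {3}


def split_by_version(records):
    """Split records into (train, test) lists by their question_version."""
    train = [r for r in records if r["question_version"] in TRAIN_VERSIONS]
    test = [r for r in records if r["question_version"] in TEST_VERSIONS]
    return train, test


if __name__ == "__main__":
    # Toy records: five paraphrases of one hypothetical question (not real data).
    toy = [{"question_version": v, "question": f"paraphrase {v}"} for v in range(5)]
    train, test = split_by_version(toy)
    print(len(train), len(test))  # 4 1
```

If the dataset is loaded with the Hugging Face `datasets` library, the same membership test could be passed to `Dataset.filter`, assuming the column is stored as an integer.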
How to Get Started with the Model
Use the code below to get started with the model. (Requires a TPU!)
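```python
# Adapted from https://www.kaggle.com/code/nilaychauhan/keras-gemma-distributed-finetuning-and-inference
import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import keras
import keras_nlp
import tensorflow_text as text

# Shard the model over 8 TPU cores (1 x 8 mesh), following the Kaggle notebook.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8), ["batch", "model"], devices=keras.distribution.list_devices()
)
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = ("model", None)
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = ("model", None, None)
layout_map["decoder_block.*attention_output.*kernel"] = ("model", None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, "model")
layout_map["decoder_block.*ffw_linear.*kernel"] = ("model", None)
model_parallel = keras.distribution.ModelParallel(
    device_mesh=device_mesh,
    layout_map=layout_map,
    batch_dim_name="batch",
)
keras.distribution.set_distribution(model_parallel)

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
gemma_lm.backbone.enable_lora(rank=6)  # rank must match the adapter (r = 6)
MODEL_LORA_WT_PATH = "path/to/adapter"  # replace with the local path of the adapter file downloaded from this repository
gemma_lm.backbone.load_lora_weights(MODEL_LORA_WT_PATH)
# Optional: beam search with 3 beams, the setting behind the 78% accuracy figure.
# gemma_lm.compile(sampler=keras_nlp.samplers.BeamSampler(num_beams=3))

def generate_response(prompt, max_length=256):
    # Return the generated text (prompt + completion) instead of only printing it
    return gemma_lm.generate(prompt, max_length=max_length)

inference_template = """<start_of_turn>user\nQuestion: {question}\n<end_of_turn>\n\n<start_of_turn>model\n"""
prompt = inference_template.format(
    question="Identify the specific fraudulent scheme highlighted in the FinCEN alert related to requests for convertible virtual currency payments."
)
print(generate_response(prompt))
```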
If you encounter errors, refer to https://www.kaggle.com/code/nilaychauhan/keras-gemma-distributed-finetuning-and-inference for how to load the model onto a TPU properly.
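KerasNLP's `generate` returns the prompt followed by the completion. The two helpers below are a minimal sketch for building a prompt in this card's format and keeping only the model's answer. Both `build_prompt` and `extract_answer` are hypothetical names and are not part of keras_nlp. The sketch assumes the output either echoes the prompt verbatim or still contains the `<start_of_turn>model` marker. If neither holds, the text is returned unchanged.

```python
# Hypothetical helpers (not part of keras_nlp): build the prompt in the format
# used by this card and cut the model's answer out of the generated text.
INFERENCE_TEMPLATE = (
    "<start_of_turn>user\nQuestion: {question}\n<end_of_turn>\n\n<start_of_turn>model\n"
)
MODEL_TURN = "<start_of_turn>model\n"
END_OF_TURN = "<end_of_turn>"


def build_prompt(question: str) -> str:
    """Fill the inference template with a single question."""
    return INFERENCE_TEMPLATE.format(question=question)


def extract_answer(generated: str, prompt: str) -> str:
    """Return only the model's answer from the text returned by generate()."""
    if generated.startswith(prompt):
        # Usual case: the output echoes the prompt, so drop it.
        answer = generated[len(prompt):]
    elif MODEL_TURN in generated:
        # Fallback: keep whatever follows the last model-turn marker.
        answer = generated.rsplit(MODEL_TURN, 1)[1]
    else:
        answer = generated
    # Stop at the end-of-turn marker if the model emitted one.
    return answer.split(END_OF_TURN, 1)[0].strip()


if __name__ == "__main__":
    prompt = build_prompt("What does FinCEN publish?")
    fake_output = prompt + "Alerts and advisories.<end_of_turn>"  # made-up text, not model output
    print(extract_answer(fake_output, prompt))  # Alerts and advisories.
```

In practice, pass the string returned by `gemma_lm.generate(prompt, max_length=256)` as `generated`, together with the same `prompt`.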