micheleriva committed · verified
Commit 26bbcaa · Parent(s): 5319897

Update README.md

Files changed (1): README.md (+33 −17)
README.md CHANGED
@@ -33,15 +33,10 @@ It understands various data types and query operators, making it versatile for d
 ## Usage

 ```python
-from transformers import AutoModelForCausalLM, AutoTokenizer
+import json, torch
 from peft import PeftModel
+from transformers import AutoModelForCausalLM, AutoTokenizer

-# Load the model and tokenizer
-model_name = "OramaSearch/query-translator-mini"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name)
-
-# System Prompt used during training
 SYSTEM_PROMPT = """
 You are a tool used to generate synthetic data of Orama queries. Orama is a full-text, vector, and hybrid search engine.

@@ -76,36 +71,57 @@ The rules to generate the query are:
 - Nested properties are supported. Just translate them into dot notation. Example: `{ "where": { "author.name": "John" } }`.
 - Array of numbers are not supported.
 - Array of booleans are not supported.
+
+Return just a JSON object, nothing more.
 """

-# Example query
-query = "What are the red wines that cost less than 20 dollars?"
+QUERY = "Show me some wine reviews with a score greater than 4.5 and less than 5.0."

-# Orama schema
-schema = {
-    "name": "string",
-    "content": "string",
+SCHEMA = {
+    "title": "string",
+    "description": "string",
     "price": "number",
-    "tags": "enum[]"
 }

-# Generate structured query
+base_model_name = "Qwen/Qwen2.5-7B"
+adapter_path = "OramaSearch/query-translator-mini"
+
+print("Loading tokenizer...")
+tokenizer = AutoTokenizer.from_pretrained(base_model_name)
+
+print("Loading base model...")
+model = AutoModelForCausalLM.from_pretrained(
+    base_model_name,
+    torch_dtype=torch.float16,
+    device_map="auto",
+    trust_remote_code=True,
+)
+
+print("Loading fine-tuned adapter...")
+model = PeftModel.from_pretrained(model, adapter_path)
+
+if torch.cuda.is_available():
+    model = model.cuda()
+    print(f"GPU memory after loading: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
+
 messages = [
     {"role": "system", "content": SYSTEM_PROMPT},
-    {"role": "user", "content": f"Query: {query}\nSchema: {json.dumps(schema)}"},
+    {"role": "user", "content": f"Query: {QUERY}\nSchema: {json.dumps(SCHEMA)}"},
 ]

 prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 outputs = model.generate(
     **inputs,
-    max_length=512,
+    max_new_tokens=512,
+    do_sample=True,
     temperature=0.1,
     top_p=0.9,
     num_return_sequences=1,
 )

 response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+print(response)
 ```

 ## Training Details
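As updated, the snippet prints the full decoded sequence, prompt included. A minimal follow-up sketch, not part of the commit, for slicing off the prompt and parsing the completion as JSON, assuming the model honors the new "Return just a JSON object, nothing more." instruction; the names `tokenizer`, `inputs`, and `outputs` refer to the updated snippet above:

```python
import json

# Decode only the tokens generated after the prompt, rather than the
# whole sequence that `response` holds in the snippet above.
completion = tokenizer.decode(
    outputs[0][inputs["input_ids"].shape[1]:],
    skip_special_tokens=True,
)

# Assumption: per the system prompt, `completion` is a single JSON object
# describing the Orama query (e.g. a `where` filter over the schema fields).
structured_query = json.loads(completion)
print(structured_query)
```

If the model emits anything besides a single JSON object, `json.loads` raises a `JSONDecodeError`, which doubles as a cheap validity check on the generation.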