Update README.md
Browse files
README.md
CHANGED
@@ -148,13 +148,12 @@ Used RunPod with following setup:
148 |     optim="paged_adamw_8bit",
149 |     save_strategy="steps",
150 | )
151 | -
152 | - bnb_config = BitsAndBytesConfig(
153 |     load_in_4bit=True,
154 |     bnb_4bit_use_double_quant=True,
155 |     bnb_4bit_quant_type="nf4",
156 |     bnb_4bit_compute_dtype=torch.bfloat16,
157 | - )
158 | <!-- #### Speeds, Sizes, Times [optional] -->
159 |
160 | <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
@@ -269,14 +268,18 @@ import torch
269 | from transformers import (
270 |     AutoModelForCausalLM,
271 |     AutoTokenizer,
272 | )
273 |
274 | instruction = (
275 |     "Generate Cypher statement to query a graph database. "
276 |     "Use only the provided relationship types and properties in the schema. \n"
277 |     "Schema: {schema} \n Question: {question} \n Cypher output: "
278 | )
279 |
280 | def prepare_chat_prompt(question, schema) -> list[dict]:
281 |     chat = [
282 |         {
|
@@ -301,17 +304,26 @@ def _postprocess_output_cypher(output_cypher: str) -> str:
301 |     return output_cypher
302 |
303 | # Model
304 | - base_model_name = "google/gemma-2-9b-it"
305 | model_name = "neo4j/text2cypher-gemma-2-9b-it-finetuned-2024v1"
306 | -
307 | -
308 | -
309 |
310 | # Question
311 | question = "What are the movies of Tom Hanks?"
312 | schema = "(:Actor)-[:ActedIn]->(:Movie)"
313 | new_message = prepare_chat_prompt(question=question, schema=schema)
314 | - tokenizer = AutoTokenizer.from_pretrained(model_name)
315 | prompt = tokenizer.apply_chat_template(new_message, add_generation_prompt=True, tokenize=False)
316 | inputs = tokenizer(prompt, return_tensors="pt", padding=True)
317 |
@@ -333,5 +345,5 @@ with torch.no_grad():
333 | outputs = [_postprocess_output_cypher(output) for output in raw_outputs]
334 |
335 | print(outputs)
336 | - > ["MATCH (
337 | ```
|
|
|
148 |     optim="paged_adamw_8bit",
149 |     save_strategy="steps",
150 | )
151 | + bnb_config = BitsAndBytesConfig(
152 |     load_in_4bit=True,
153 |     bnb_4bit_use_double_quant=True,
154 |     bnb_4bit_quant_type="nf4",
155 |     bnb_4bit_compute_dtype=torch.bfloat16,
156 | + )
157 | <!-- #### Speeds, Sizes, Times [optional] -->
158 |
159 | <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->

268 | from transformers import (
269 |     AutoModelForCausalLM,
270 |     AutoTokenizer,
271 | +   BitsAndBytesConfig,
272 | )
273 |
274 | +
275 | instruction = (
276 |     "Generate Cypher statement to query a graph database. "
277 |     "Use only the provided relationship types and properties in the schema. \n"
278 |     "Schema: {schema} \n Question: {question} \n Cypher output: "
279 | )
280 |
281 | +
282 | +
283 | def prepare_chat_prompt(question, schema) -> list[dict]:
284 |     chat = [
285 |         {
|
|
|
304 |     return output_cypher
305 |
306 | # Model
307 | model_name = "neo4j/text2cypher-gemma-2-9b-it-finetuned-2024v1"
308 | + bnb_config = BitsAndBytesConfig(
309 | +     load_in_4bit=True,
310 | +     bnb_4bit_use_double_quant=True,
311 | +     bnb_4bit_quant_type="nf4",
312 | +     bnb_4bit_compute_dtype=torch.bfloat16,
313 | + )
314 | + tokenizer = AutoTokenizer.from_pretrained(model_name)
315 | + model = AutoModelForCausalLM.from_pretrained(
316 | +     model_name,
317 | +     quantization_config=bnb_config,
318 | +     torch_dtype=torch.bfloat16,
319 | +     attn_implementation="eager",
320 | +     low_cpu_mem_usage=True,
321 | + )
322 |
323 | # Question
324 | question = "What are the movies of Tom Hanks?"
325 | schema = "(:Actor)-[:ActedIn]->(:Movie)"
326 | new_message = prepare_chat_prompt(question=question, schema=schema)
327 | prompt = tokenizer.apply_chat_template(new_message, add_generation_prompt=True, tokenize=False)
328 | inputs = tokenizer(prompt, return_tensors="pt", padding=True)
329 |
345 | outputs = [_postprocess_output_cypher(output) for output in raw_outputs]
346 |
347 | print(outputs)
348 | + > ["MATCH (a:Actor {Name: 'Tom Hanks'})-[:ActedIn]->(m:Movie) RETURN m"]
349 | ```
|