santhosh97 commited on
Commit
6e9ff05
1 Parent(s): fe5151f

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +87 -0
README.md ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Gretel's baseline text2table was fine-tuned on togethercomputer's RedPajama-INCITE-instruct-3B-v1 model for 100 epochs on 8A100 80GB gpu's. The fine-tuning used ~2k training samples (text and table pairs) that were generated using OpenAI.
2
+
3
+ ## Data Formatting
4
+
5
+ ```python
6
+ INSTRUCTION_KEY = "### Instruction: Given the following prompt, generate a table"
7
+ RESPONSE_KEY = "### Response:"
8
+ INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
9
+ PROMPT_FOR_GENERATION_FORMAT = """{intro}
10
+ {instruction_key}
11
+ {prompt_to_generate_table}
12
+ {response_key}
13
+ {table}
14
+ """.format(
15
+ intro=INTRO_BLURB,
16
+ instruction_key=INSTRUCTION_KEY,
17
+ prompt_to_generate_table"{PROMPT}",
18
+ response_key=RESPONSE_KEY,
19
+ table="{TABLE}"
20
+ )
21
+ ```
22
+
23
+ ## For generation purposes:
24
+
25
+ ```python
26
+ import torch
27
+ from transformers import (
28
+ AutoModelForCausalLM,
29
+ AutoTokenizer,
30
+ )
31
+ tokenizer = AutoTokenizer.from_pretrained('gretelai/text2table', padding_side="right")
32
+ model = AutoModelForCausalLM.from_pretrained('gretelai/text2table').to('cuda', dtype=torch.bfloat16)
33
+
34
+ model.eval()
35
+
36
+ INSTRUCTION_KEY = "### Instruction: Given the following prompt, generate a table. Each column should have random values."
37
+ RESPONSE_KEY = "### Response:"
38
+ INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
39
+ PROMPT_FOR_GENERATION_FORMAT = """{intro}
40
+ {instruction_key}
41
+ {prompt_to_generate_table}
42
+ {response_key}
43
+ """.format(
44
+ intro=INTRO_BLURB,
45
+ instruction_key=INSTRUCTION_KEY,
46
+ prompt_to_generate_table="{PROMPT}",
47
+ response_key=RESPONSE_KEY,
48
+ )
49
+
50
+ PROMPT = "Create a dataset with four columns: patient, sex, agegrp, bp_before and bp_after. The patient column is a numerical identifier, sex is the gender of the patient, agegrp is the age group of the patient, bp_before is the blood pressure (in mmHg) before a certain treatment, and bp_after is the blood pressure (in mmHg) after a certain treatment."
51
+ inputs = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
52
+ tokenizer.pad_token = tokenizer.eos_token
53
+ input = tokenizer(inputs, return_tensors="pt").to('cuda')
54
+ input_ids = input['input_ids']
55
+ outputs = model.generate(**input, max_length = 1024)
56
+ table = tokenizer.decode(outputs[0], skip_special_tokens=False)
57
+ ```
58
+
59
+ ## Output
60
+
61
+ ```python
62
+ PROMPT = "Create a dataset with four columns: patient, sex, agegrp, bp_before and bp_after. The patient column is a numerical identifier, sex is the gender of the patient, agegrp is the age group of the patient, bp_before is the blood pressure (in mmHg) before a certain treatment, and bp_after is the blood pressure (in mmHg) after a certain treatment."
63
+
64
+ MODEL GENERATION ->
65
+
66
+ Below is an instruction that describes a task. Write a response that appropriately completes the request.
67
+ Instruction: Given the following prompt, generate a table. Each column should have random values.
68
+ Create a dataset with four columns: patient, sex, agegrp, bp_before and bp_after. The patient column is a numerical identifier, sex is the gender of the patient, agegrp is the age group of the patient, bp_before is the blood pressure (in mmHg) before a certain treatment, and bp_after is the blood pressure (in mmHg) after a certain treatment.
69
+ Response:
70
+ patient,sex,agegrp,bp_before,bp_after
71
+ 1.0,F,45.0,183.0,124.0,234.0
72
+ 2.0,F,60.0,183.0,124.0,183.0
73
+ 3.0,F,70.0,179.0,117.0,183.0
74
+ 4.0,M,30.0,141.0,136.0,161.0
75
+ 5.0,M,70.0,147.0,129.0,157.0
76
+ 6.0,M,40.0,140.0,136.0,156.0
77
+ 7.0,M,60.0,140.0,116.0,157.0
78
+ 8.0,M,70.0,144.0,131.0,161.0
79
+ 9.0,M,60.0,142.0,119.0,157.0
80
+ 10.0,M,70.0,147.0,132.0,167.0
81
+ 11.0,M,60.0,147.0,136.0,166.0
82
+ 12.0,M,70.0,150.0,132.0,172.0
83
+ 13.0,M,60.0,149.0,137.0,162.0
84
+ 14.0,M,70.0,156.0,124.0,157.0
85
+ 15.0,M,60.0,156.0,181.0,157.0
86
+ 16.0,M,70.0,156.0,131.0,158.0
87
+ ```