Mahalingam committed
Commit 23975dc · 1 Parent(s): ff2e42d

Files changed (2):
  1. fine_tune_model.py +111 -0
  2. testgensummary.py +45 -0
fine_tune_model.py ADDED
@@ -0,0 +1,111 @@
+ from transformers import Trainer, TrainingArguments, T5ForConditionalGeneration, T5Tokenizer
+ import json
+ from pathlib import Path
+
+ # Load the T5 model and tokenizer from a local directory
+ model_path = "t5-small-model"
+ tokenizer = T5Tokenizer.from_pretrained(model_path)
+ model = T5ForConditionalGeneration.from_pretrained(model_path)
+
+ # Define the training arguments
+ training_args = TrainingArguments(
+     output_dir="./output1",           # Output directory for model checkpoints and predictions
+     save_steps=100,                   # Save a checkpoint every 100 steps
+     per_device_train_batch_size=4,    # Adjust the batch size to your GPU memory
+     save_total_limit=2,               # Keep at most two checkpoints
+     num_train_epochs=3,               # Number of training epochs
+     logging_dir="./logs",             # Directory for TensorBoard logs
+ )
+
+ # Parse one JSON file into a list of tokenized training examples,
+ # or return None if the file is not valid JSON.
+ def format_dataset(file_path):
+     with open(file_path, "r", encoding="utf-8") as f:
+         content = f.read()
+
+     try:
+         data_list = json.loads(content)
+     except json.JSONDecodeError as e:
+         print(f"Error decoding JSON in {file_path}: {e}")
+         return None
+
+     formatted_examples = []
+     for data in data_list:
+         input_texts = data.get("input")
+         targets = data.get("target")
+
+         # Convert to lists if not already
+         if not isinstance(input_texts, list):
+             input_texts = [input_texts]
+         if not isinstance(targets, list):
+             targets = [targets]
+
+         # Concatenate the texts in each list
+         input_text_concatenated = " ".join(input_texts)
+         target_text_concatenated = " ".join(targets)
+
+         # Encode the concatenated texts with padding and truncation
+         inputs = tokenizer(
+             input_text_concatenated,
+             padding="max_length",
+             truncation=True,
+             return_tensors="pt",
+             max_length=512,
+         )
+         labels = tokenizer(
+             target_text_concatenated,
+             padding="max_length",
+             truncation=True,
+             return_tensors="pt",
+             max_length=512,
+         )
+
+         # Mask padding tokens in the labels so the loss ignores them
+         label_ids = labels["input_ids"]
+         label_ids[label_ids == tokenizer.pad_token_id] = -100
+
+         # Update the inputs dictionary with the labels
+         inputs["labels"] = label_ids
+         formatted_examples.append(inputs)
+
+     return formatted_examples
+
+ # Process each file once, skipping files that failed to parse
+ data_files = Path("./files/").rglob("*.json")
+ per_file_examples = (format_dataset(file_path) for file_path in data_files)
+ formatted_examples = [examples for examples in per_file_examples if examples is not None]
+
+ # Flatten the list of examples
+ formatted_examples = [example for sublist in formatted_examples for example in sublist]
+
+ # Create the final dataset, stripping the batch dimension added by return_tensors="pt"
+ train_dataset = [
+     {
+         "input_ids": example["input_ids"][0],
+         "attention_mask": example["attention_mask"][0],
+         "labels": example["labels"][0],
+     }
+     for example in formatted_examples
+ ]
+
+ # Instantiate the Trainer, passing the tokenizer so it is saved with checkpoints
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=train_dataset,
+     tokenizer=tokenizer,
+ )
+
+ print(f"Number of examples in the training dataset: {len(train_dataset)}")
+
+ # Print model configuration
+ print("Model Configuration:")
+ print(model.config)
+
+ # Training loop
+ trainer.train()
+
+ # Save the model and tokenizer after training
+ model.save_pretrained("./output/fine-tuned-model")
+ tokenizer.save_pretrained("./output/fine-tuned-model")
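
For reference, format_dataset() above expects each file under ./files/ to be a JSON list of records with "input" and "target" fields (strings or lists of strings). A minimal sketch of writing one such file; the field values here are hypothetical examples, not part of this commit:

import json
from pathlib import Path

# Hypothetical sample in the layout format_dataset() parses:
# a JSON list of objects with "input" and "target" fields.
sample = [
    {
        "input": "summarize: Patient presents with fever and chills.",
        "target": "A 28-year-old male presents with fever and chills.",
    }
]

Path("./files").mkdir(exist_ok=True)
Path("./files/sample.json").write_text(json.dumps(sample, indent=2), encoding="utf-8")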
testgensummary.py ADDED
@@ -0,0 +1,45 @@
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
+
+ # Load the fine-tuned model and tokenizer
+ model_name = "C:\\fine-tuned-model"
+ tokenizer = T5Tokenizer.from_pretrained(model_name)
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
+
+ # Prompt
+ prompt = """Write a detailed medical summary in paragraph format, including patient details such as Sex and Age along with the medical details, from the data below
+ {
+     "Sex": "M",
+     "ID": 585248,
+     "DateOfBirth": "08/10/1995",
+     "Age": "28 years",
+     "VisitDate": "09/25/2023",
+     "LogNumber": 6418481,
+     "Historian": "Self",
+     "TriageNotes": ["fever"],
+     "HistoryOfPresentIllness": {
+         "Complaint": [
+             "The patient presents with a chief complaint of chills.",
+             "The problem is made better by exercise and rest.",
+             "The patient also reports change in appetite and chest pain/pressure as abnormal symptoms related to the complaint."
+         ]
+     }
+ }"""
+
+ # Tokenize, then generate text with sampling and beam search
+ input_ids = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
+ generated_ids = model.generate(
+     input_ids,
+     max_length=200,
+     num_beams=5,
+     temperature=0.9,          # Higher temperature gives more randomness
+     no_repeat_ngram_size=2,
+     top_k=50,
+     top_p=0.95,
+     early_stopping=True,
+     do_sample=True,
+ )
+
+ # Decode and print the generated text
+ decoded_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+ print(f"Generated Text: {decoded_text}")
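
Note that the generate() call above mixes beam search (num_beams=5) with sampling (do_sample=True), so output varies between runs. A minimal sketch separating the two decoding modes, assuming the same model, tokenizer, and input_ids as in the script:

# Deterministic beam search: reproducible output across runs.
beam_ids = model.generate(input_ids, max_length=200, num_beams=5, early_stopping=True)
print(tokenizer.decode(beam_ids[0], skip_special_tokens=True))

# Pure nucleus sampling: more varied, non-deterministic output.
sampled_ids = model.generate(input_ids, max_length=200, do_sample=True, top_p=0.95, temperature=0.9)
print(tokenizer.decode(sampled_ids[0], skip_special_tokens=True))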