danielhajialigol commited on
Commit
1841ebe
·
1 Parent(s): eca4ff8

fixed model issue

Browse files
Files changed (4) hide show
  1. all_summaries.csv +2 -2
  2. app.py +17 -3
  3. discharge_embeddings.pt +2 -2
  4. model.py +7 -1
all_summaries.csv CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3c74c03c1a4b5ad01eff9eea8a7062c660d342bf912794191cd5e3ebeb4abe44
3
- size 1114819287
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64b8415ab96d12e61393c8ca12ebb844ab32a57df314984e98e91e1064bebf41
3
+ size 640698121
app.py CHANGED
@@ -4,9 +4,14 @@ import pandas as pd
4
  import torch
5
 
6
  from model import MimicTransformer
7
- from utils import load_rule, get_attribution, get_diseases, get_drg_link, get_icd_annotations, visualize_attn
8
  from transformers import AutoTokenizer, AutoModel, set_seed, pipeline
9
 
 
 
 
 
 
10
  model_path = 'checkpoint_0_9113.bin'
11
  related_tensor = torch.load('discharge_embeddings.pt')
12
  all_summaries = pd.read_csv('all_summaries.csv')['SUMMARIES'].to_list()
@@ -16,7 +21,9 @@ similarity_model = AutoModel.from_pretrained('kamalkraj/BioSimCSE-BioLinkBERT-BA
16
  similarity_model.eval()
17
 
18
  def read_model(model, path):
19
- model.load_state_dict(torch.load(path, map_location=torch.device('cpu')), strict=False)
 
 
20
  return model
21
 
22
  mimic = MimicTransformer(cutoff=512)
@@ -50,8 +57,10 @@ def mean_pooling(model_output, attention_mask):
50
 
51
 
52
  def get_model_results(text):
 
53
  inputs = tokenizer(text, return_tensors='pt', padding='max_length', max_length=512, truncation=True)
54
- outputs = mimic(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, drg_labels=None)
 
55
  attribution, reconstructed_text = get_attribution(text=text, tokenizer=tokenizer, model_outputs=outputs, inputs=inputs, k=10)
56
  logits = outputs[0][0]
57
  out = logits.detach().cpu()[0]
@@ -93,7 +102,12 @@ def find_related_summaries(text):
93
 
94
 
95
  def run(text, related_discharges=False):
 
 
 
 
96
  # initial drg results
 
97
  model_results = get_model_results(text=text)
98
  drg_code = model_results['class']
99
 
 
4
  import torch
5
 
6
  from model import MimicTransformer
7
+ from utils import load_rule, get_attribution, get_diseases, get_drg_link, get_icd_annotations, visualize_attn, clean_text
8
  from transformers import AutoTokenizer, AutoModel, set_seed, pipeline
9
 
10
+ torch.manual_seed(0)
11
+ set_seed(34)
12
+ if torch.cuda.is_available():
13
+ torch.cuda.manual_seed_all(0)
14
+
15
  model_path = 'checkpoint_0_9113.bin'
16
  related_tensor = torch.load('discharge_embeddings.pt')
17
  all_summaries = pd.read_csv('all_summaries.csv')['SUMMARIES'].to_list()
 
21
  similarity_model.eval()
22
 
23
  def read_model(model, path):
24
+ # model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
25
+ state_dict = torch.load(path, map_location='cpu')
26
+ model.load_state_dict({"model."+k: v for k, v in state_dict.items()}, strict=False)
27
  return model
28
 
29
  mimic = MimicTransformer(cutoff=512)
 
57
 
58
 
59
  def get_model_results(text):
60
+ text = clean_text(text)
61
  inputs = tokenizer(text, return_tensors='pt', padding='max_length', max_length=512, truncation=True)
62
+ with torch.no_grad():
63
+ outputs = mimic(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, drg_labels=None)
64
  attribution, reconstructed_text = get_attribution(text=text, tokenizer=tokenizer, model_outputs=outputs, inputs=inputs, k=10)
65
  logits = outputs[0][0]
66
  out = logits.detach().cpu()[0]
 
102
 
103
 
104
  def run(text, related_discharges=False):
105
+ torch.manual_seed(0)
106
+ set_seed(34)
107
+ if torch.cuda.is_available():
108
+ torch.cuda.manual_seed_all(0)
109
  # initial drg results
110
+
111
  model_results = get_model_results(text=text)
112
  drg_code = model_results['class']
113
 
discharge_embeddings.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2179abc7e448f3bb4a091f4f27eadc5b1d4829464eeb9bfb7fd9c6363844aaaa
3
- size 1228800786
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cbc05e83aa36756a35bee2f104e3c3dcc8fb1f26442d89ff52916d7052cd036b
3
+ size 713869074
model.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
2
  from torch.utils.data import DataLoader
3
  from torch.nn import Linear, Module
4
  from typing import Dict, List
@@ -6,6 +6,11 @@ from collections import Counter, defaultdict
6
  from itertools import chain
7
  import torch
8
 
 
 
 
 
 
9
  class MimicTransformer(Module):
10
  def __init__(self, num_labels=738, tokenizer_name='clinical', cutoff=512):
11
  """
@@ -17,6 +22,7 @@ class MimicTransformer(Module):
17
  self.config = AutoConfig.from_pretrained(self.tokenizer_name, num_labels=self.num_labels)
18
  self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, config=self.config)
19
  self.model = AutoModelForSequenceClassification.from_pretrained(self.tokenizer_name, config=self.config)
 
20
  if 'longformer' in self.tokenizer_name:
21
  self.cutoff = self.model.config.max_position_embeddings
22
  else:
 
1
+ from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification, set_seed
2
  from torch.utils.data import DataLoader
3
  from torch.nn import Linear, Module
4
  from typing import Dict, List
 
6
  from itertools import chain
7
  import torch
8
 
9
+ torch.manual_seed(0)
10
+ set_seed(34)
11
+ if torch.cuda.is_available():
12
+ torch.cuda.manual_seed_all(0)
13
+
14
  class MimicTransformer(Module):
15
  def __init__(self, num_labels=738, tokenizer_name='clinical', cutoff=512):
16
  """
 
22
  self.config = AutoConfig.from_pretrained(self.tokenizer_name, num_labels=self.num_labels)
23
  self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, config=self.config)
24
  self.model = AutoModelForSequenceClassification.from_pretrained(self.tokenizer_name, config=self.config)
25
+ self.model.eval()
26
  if 'longformer' in self.tokenizer_name:
27
  self.cutoff = self.model.config.max_position_embeddings
28
  else: