mtyrrell committed
Commit c8a9cbc · 1 Parent(s): 9e05a35

truncation bug fix

Files changed (3)
  1. .gitignore +0 -3
  2. app.py +1 -3
  3. modules/utils.py +29 -12
.gitignore CHANGED
@@ -4,8 +4,5 @@
 *.xlsx
 /testing/
 /modules/__pycache__/
-/logs/
 app.log
-logs
-logs/
 /sandbox/
app.py CHANGED
@@ -131,7 +131,7 @@ def main():
         logger.info(f"File uploaded: {uploaded_file.name}")
 
         if not st.session_state['data_processed']:
-            logger.info("Starting data processing...")
+            logger.info("Starting data processing")
             try:
                 st.session_state['df'] = process_data(uploaded_file, sens_level)
                 logger.info("Data processing completed successfully")
@@ -141,8 +141,6 @@ def main():
                 raise
 
         df = st.session_state['df']
-        logger.info(f"DataFrame columns: {list(df.columns)}")
-        logger.info(f"DataFrame shape: {df.shape}")
 
         current_datetime = datetime.now().strftime('%d-%m-%Y_%H-%M-%S')
         output_filename = f'processed_applications_{current_datetime}.csv'
modules/utils.py CHANGED
@@ -88,32 +88,47 @@ def predict_category(df, model_name, progress_bar, repo, profile, multilabel=Fal
         col_name = re.sub(r'_(.*)', r'_txt', model_name)
         model = SetFitModel.from_pretrained(profile+"/"+repo)
         model.to(device)
+        # Get tokenizer from the model
+        tokenizer = model.model_body.tokenizer
     else:
         col_name = 'scope_txt'
-        model = pipeline("text-classification", model=profile+"/"+repo, device=device, return_all_scores=multilabel)
+        model = pipeline("text-classification",
+                         model=profile+"/"+repo,
+                         device=device,
+                         return_all_scores=multilabel,
+                         truncation=True,
+                         max_length=512)
     predictions = []
     total = len(df)
     for i, text in enumerate(df[col_name]):
-        prediction = model(text)
-        if model_name in model_names_sf:
-            predictions.append(0 if prediction == 'NEGATIVE' else 1)
-        elif model_name == 'ADAPMIT':
-            predictions.append(re.sub('Label$', '', prediction[0]['label']))
-        elif model_name == 'SECTOR':
-            predictions.append(extract_predicted_labels(prediction[0], threshold=0.5))
-        elif model_name == 'LANG':
-            predictions.append(prediction[0]['label'])
+        try:
+            if model_name in model_names_sf:
+                # Truncate text for SetFit models
+                encoded = tokenizer(text, truncation=True, max_length=512)
+                truncated_text = tokenizer.decode(encoded['input_ids'])
+                prediction = model(truncated_text)
+                predictions.append(0 if prediction == 'NEGATIVE' else 1)
+            else:
+                prediction = model(text)
+                if model_name == 'ADAPMIT':
+                    predictions.append(re.sub('Label$', '', prediction[0]['label']))
+                elif model_name == 'SECTOR':
+                    predictions.append(extract_predicted_labels(prediction[0], threshold=0.5))
+                elif model_name == 'LANG':
+                    predictions.append(prediction[0]['label'])
+        except Exception as e:
+            logger.error(f"Error processing sample {df['id'][i]}: {str(e)}")
+            st.error("Application Error. Please contact support.")
         # Update progress bar with each iteration
         progress = (i + 1) / total
         progress_bar.progress(progress)
-    # st.write(predictions)
     return predictions
 
 
 # Main function to process data
 def process_data(uploaded_file, sens_level):
     df = pd.read_excel(uploaded_file)
-    logger.info(f"data import successful")
+    logger.info(f"Data import successful")
     # Rename columns
     df.rename(columns={
         'id': 'id',
@@ -147,6 +162,7 @@ def process_data(uploaded_file, sens_level):
     step_count = 0
     total_steps = len(model_names)
     for model_name in model_names:
+        logger.info(f"Loading: {model_name}")
         step_count += 1
         model_processing_text = st.empty()
         model_processing_text.markdown(f'**Current Task: Processing with model "{model_name}"**')
@@ -165,6 +181,7 @@ def process_data(uploaded_file, sens_level):
         elif model_name == 'LANG':
             df[model_name] = predict_category(df, model_name, progress_bar, repo='51-languages-classifier', profile='qanastek')
 
+        logger.info(f"Completed: {model_name}")
        model_progress.empty()
 
        progress_count += len(df)
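
For reference, a minimal standalone sketch of the truncation strategy this commit applies, with placeholder repo IDs ("some-profile/...") and the same 512-token limit used above: SetFit models do not truncate long inputs on their own, so text is pre-truncated through the underlying sentence-transformers tokenizer, while the transformers pipeline is told to truncate internally via truncation=True / max_length=512.

# Sketch only: mirrors the commit's approach; repo IDs below are placeholders,
# not the app's actual configuration.
from setfit import SetFitModel
from transformers import pipeline

MAX_LENGTH = 512  # matches the max_length used in the commit

# Cut the text to MAX_LENGTH tokens with the SetFit model's underlying
# sentence-transformers tokenizer, then decode it back to a string.
setfit_model = SetFitModel.from_pretrained("some-profile/some-setfit-repo")
tokenizer = setfit_model.model_body.tokenizer

def truncate_text(text: str) -> str:
    encoded = tokenizer(text, truncation=True, max_length=MAX_LENGTH)
    # skip_special_tokens keeps [CLS]/[SEP] out of the re-decoded string
    return tokenizer.decode(encoded["input_ids"], skip_special_tokens=True)

label = setfit_model(truncate_text("a very long application text ..."))

# A transformers pipeline can truncate for itself; the extra tokenizer kwargs
# are forwarded when each input is preprocessed.
classifier = pipeline(
    "text-classification",
    model="some-profile/some-classifier-repo",
    truncation=True,
    max_length=MAX_LENGTH,
)
result = classifier("a very long application text ...")

Decoding back to text (rather than feeding token IDs) keeps the SetFit predict path unchanged; the over-length inputs that presumably triggered the original bug are simply shortened before they reach the model.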