Upyaya commited on
Commit
d6e285e
1 Parent(s): 888b76c

Removed use of varaible "init_model_required" to init model.

Browse files

Previously used a variable "init_model_required" that makes sure the model is loaded and init only once. Was getting an error " local variable 'init_model_required' referenced before assignment". So removed the use of the variable.

Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -1,6 +1,6 @@
1
  from transformers import Blip2ForConditionalGeneration
2
  from transformers import Blip2Processor
3
- from peft import PeftModel, PeftConfig
4
  import streamlit as st
5
  from PIL import Image
6
  import torch
@@ -9,22 +9,22 @@ preprocess_ckp = "Salesforce/blip2-opt-2.7b" #Checkpoint path used for perproces
9
  base_model_ckp = "/model/blip2-opt-2.7b-fp16-sharded" #Base model checkpoint path
10
  peft_model_ckp = "/model/blip2_peft" #PEFT model checkpoint path
11
 
12
- init_model_required = True
13
  processor = None
14
  model = None
15
 
16
  def init_model():
17
 
18
- if init_model_required:
19
 
20
- #Preprocess input
21
- processor = Blip2Processor.from_pretrained(preprocess_ckp)
22
 
23
- #Model
24
- model = Blip2ForConditionalGeneration.from_pretrained(base_model_ckp, load_in_8bit = True, device_map = "auto")
25
- model = PeftModel.from_pretrained(model, peft_model_ckp)
26
 
27
- init_model_required = False
28
 
29
 
30
 
 
1
  from transformers import Blip2ForConditionalGeneration
2
  from transformers import Blip2Processor
3
+ from peft import PeftModel
4
  import streamlit as st
5
  from PIL import Image
6
  import torch
 
9
  base_model_ckp = "/model/blip2-opt-2.7b-fp16-sharded" #Base model checkpoint path
10
  peft_model_ckp = "/model/blip2_peft" #PEFT model checkpoint path
11
 
12
+ #init_model_required = True
13
  processor = None
14
  model = None
15
 
16
  def init_model():
17
 
18
+ #if init_model_required:
19
 
20
+ #Preprocess input
21
+ processor = Blip2Processor.from_pretrained(preprocess_ckp)
22
 
23
+ #Model
24
+ model = Blip2ForConditionalGeneration.from_pretrained(base_model_ckp, load_in_8bit = True, device_map = "auto")
25
+ model = PeftModel.from_pretrained(model, peft_model_ckp)
26
 
27
+ #init_model_required = False
28
 
29
 
30