mery22 committed on
Commit
29ef5c3
1 Parent(s): c63fc8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -1
app.py CHANGED
@@ -31,12 +31,38 @@ model_config = transformers.AutoConfig.from_pretrained(
31
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
32
  tokenizer.pad_token = tokenizer.eos_token
33
  tokenizer.padding_side = "right"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  #############################################################
36
  # Load pre-trained config
37
  #################################################################
38
  model = AutoModelForCausalLM.from_pretrained(
39
- "mistralai/Mistral-7B-Instruct-v0.1"
40
  )
41
  # Connect query to FAISS index using a retriever
42
  retriever = db.as_retriever(
 
31
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
32
  tokenizer.pad_token = tokenizer.eos_token
33
  tokenizer.padding_side = "right"
34
+ #################################################################
35
+ # bitsandbytes parameters
36
+ #################################################################
37
+
38
+ # Activate 4-bit precision base model loading
39
+ use_4bit = True
40
+
41
+ # Compute dtype for 4-bit base models
42
+ bnb_4bit_compute_dtype = "float16"
43
+
44
+ # Quantization type (fp4 or nf4)
45
+ bnb_4bit_quant_type = "nf4"
46
+
47
+ # Activate nested quantization for 4-bit base models (double quantization)
48
+ use_nested_quant = False
49
+ #################################################################
50
+ # Set up quantization config
51
+ #################################################################
52
+ compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
53
+
54
+ bnb_config = BitsAndBytesConfig(
55
+ load_in_4bit=use_4bit,
56
+ bnb_4bit_quant_type=bnb_4bit_quant_type,
57
+ bnb_4bit_compute_dtype=compute_dtype,
58
+ bnb_4bit_use_double_quant=use_nested_quant,
59
+ )
60
 
61
  #############################################################
62
  # Load pre-trained config
63
  #################################################################
64
  model = AutoModelForCausalLM.from_pretrained(
65
+ "mistralai/Mistral-7B-Instruct-v0.1",quantization_config=bnb_config,
66
  )
67
  # Connect query to FAISS index using a retriever
68
  retriever = db.as_retriever(