Kr08 committed
Commit e8ce33d · verified · 1 Parent(s): e499054

Update app.py

Files changed (1):
  1. app.py +43 -6
app.py CHANGED
@@ -19,21 +19,58 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 def load_qa_model():
-    """Load question-answering model"""
+    """Load question-answering model with support for long input contexts."""
     try:
-        model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
-        qa_pipeline = pipeline(
-            "text-generation",
-            model="hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4",
-            model_kwargs={"torch_dtype": torch.bfloat16},
+        from transformers import AutoTokenizer, AutoModelForCausalLM
+
+        model_id = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4"
+
+        # Load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=os.getenv("HF_TOKEN"))
+        tokenizer.model_max_length = 8192  # Ensure the tokenizer can handle 8192 tokens
+
+        # Load the model
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.bfloat16,
             device_map="auto",
+            rope_scaling={
+                "type": "dynamic",  # Ensure compatibility with long contexts
+                "factor": 8.0
+            },
             use_auth_token=os.getenv("HF_TOKEN")
         )
+
+        # Load the pipeline
+        qa_pipeline = pipeline(
+            "text-generation",
+            model=model,
+            tokenizer=tokenizer,
+            max_new_tokens=4096,  # Adjust as needed for your use case
+        )
+
         return qa_pipeline
+
     except Exception as e:
         logger.error(f"Failed to load Q&A model: {str(e)}")
         return None
 
+# def load_qa_model():
+#     """Load question-answering model"""
+#     try:
+#         model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+#         qa_pipeline = pipeline(
+#             "text-generation",
+#             model="hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4",
+#             model_kwargs={"torch_dtype": torch.bfloat16},
+#             device_map="auto",
+#             use_auth_token=os.getenv("HF_TOKEN")
+#         )
+#         return qa_pipeline
+#     except Exception as e:
+#         logger.error(f"Failed to load Q&A model: {str(e)}")
+#         return None
+
 def load_summarization_model():
     """Load summarization model"""
     try:
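
Below is a minimal usage sketch for the loader this commit introduces, assuming app.py exposes load_qa_model() as shown in the diff. The import path, document file, question text, and generation settings are illustrative placeholders, not part of the commit; the chat-template call relies on the tokenizer shipped with the Llama 3.1 checkpoint.

from app import load_qa_model  # hypothetical import path for this Space

qa_pipeline = load_qa_model()
if qa_pipeline is not None:
    long_document = open("report.txt").read()  # placeholder long input (up to ~8k tokens)
    question = "What are the key findings?"    # placeholder question

    messages = [
        {"role": "system", "content": "Answer the question using only the provided context."},
        {"role": "user", "content": f"Context:\n{long_document}\n\nQuestion: {question}"},
    ]
    # Render the model's chat template into a plain prompt string
    prompt = qa_pipeline.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # return_full_text=False keeps only the newly generated answer
    outputs = qa_pipeline(prompt, max_new_tokens=256, do_sample=False, return_full_text=False)
    print(outputs[0]["generated_text"])

One note on the change itself: switching from passing a model name string to the pipeline to loading the tokenizer and model explicitly is what makes it possible to raise tokenizer.model_max_length and pass a rope_scaling override. Whether a {"type": "dynamic"} rope_scaling override is honored for a Llama 3.1 checkpoint depends on the installed transformers version, so it is worth verifying the effective context length at runtime.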