Kr08 committed
Commit
9bc426c
1 Parent(s): e8ce33d

Update app.py

Files changed (1): app.py +8 -7
app.py CHANGED
@@ -19,7 +19,7 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 def load_qa_model():
-    """Load question-answering model with support for long input contexts."""
+    """Load question-answering model with long context support."""
     try:
         from transformers import AutoTokenizer, AutoModelForCausalLM
 
@@ -27,26 +27,26 @@ def load_qa_model():
 
         # Load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=os.getenv("HF_TOKEN"))
-        tokenizer.model_max_length = 8192  # Ensure the tokenizer can handle 8192 tokens
+        tokenizer.model_max_length = 8192  # Configure tokenizer for long inputs
 
-        # Load the model
+        # Load the model with simplified rope_scaling configuration
         model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            rope_scaling={
-               "type": "dynamic",  # Ensure compatibility with long contexts
-               "factor": 8.0
+               "type": "dynamic",  # Simplified type as expected by the model
+               "factor": 8.0  # Scaling factor to support longer contexts
            },
            use_auth_token=os.getenv("HF_TOKEN")
         )
 
-        # Load the pipeline
+        # Initialize the pipeline
         qa_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
-           max_new_tokens=4096,  # Adjust as needed for your use case
+           max_new_tokens=256,  # Limit generation as needed
         )
 
         return qa_pipeline
@@ -55,6 +55,7 @@ def load_qa_model():
         logger.error(f"Failed to load Q&A model: {str(e)}")
         return None
 
+
 # def load_qa_model():
 # """Load question-answering model"""
 # try:
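
For reference, a minimal sketch of how the returned pipeline might be exercised after this change. This sketch is not part of the commit: it assumes model_id, os, torch, and pipeline are defined or imported earlier in app.py (outside this diff's hunks), and the prompt format is illustrative.

# Hypothetical usage sketch; assumes the imports and model_id defined
# elsewhere in app.py are in scope.
qa_pipeline = load_qa_model()
if qa_pipeline is not None:
    # A text-generation pipeline returns a list of dicts whose
    # "generated_text" field contains the prompt plus the completion,
    # capped at max_new_tokens=256 by the commit above.
    result = qa_pipeline("Question: What is RoPE scaling?\nAnswer:")
    print(result[0]["generated_text"])

For Llama-family models in transformers, {"type": "dynamic", "factor": 8.0} selects dynamic NTK RoPE scaling at load time, which is what makes the 8192-token tokenizer limit set above usable in practice.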