hienbm commited on
Commit
96b4622
1 Parent(s): 67fe8d2

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -41
app.py CHANGED
@@ -26,54 +26,25 @@ from langchain_community.embeddings import HuggingFaceEmbeddings
26
  from langchain_community.vectorstores import FAISS
27
  from langchain.schema.runnable import RunnablePassthrough
28
  from langchain_core.messages import AIMessage, HumanMessage
 
29
  from dotenv import load_dotenv
30
 
 
 
 
31
  # Get the API token from environment variable
32
  api_token = os.getenv("API_TOKEN")
33
 
34
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:15000"
35
-
36
- model_id = "google/gemma-2-9b-it"
37
- quantization_config = BitsAndBytesConfig(load_in_4bit=True)
38
 
39
- tokenizer = AutoTokenizer.from_pretrained(
40
- model_id,
41
- return_tensors="pt",
42
- padding=True,
43
- truncation=True,
44
- trust_remote_code=True,
45
- )
46
- tokenizer.pad_token = tokenizer.eos_token
47
- tokenizer.padding_side = "right"
48
-
49
- model = AutoModelForCausalLM.from_pretrained(
50
- model_id,
51
- quantization_config=quantization_config,
52
- device_map="auto",
53
- low_cpu_mem_usage=True,
54
- pad_token_id=0,
55
  )
56
- model.config.use_cache = False
57
-
58
- # Create a text generation pipeline with specific settings
59
- pipe = transformers.pipeline(
60
- task="text-generation",
61
- model=model,
62
- tokenizer=tokenizer,
63
- torch_dtype=torch.float16,
64
- device_map="auto",
65
- # do_sample=True,
66
- # top_k=10,
67
- temperature=0.0,
68
- top_p=0.9,
69
- num_return_sequences=1,
70
- eos_token_id=tokenizer.eos_token_id,
71
- max_length=4096,
72
- truncation=True,
73
- )
74
-
75
- chat_model = HuggingFacePipeline(pipeline=pipe)
76
-
77
 
78
  template = """
79
  You are a genius trader with extensive knowledge of the financial and stock markets, capable of providing deep and insightful analysis of financial stocks with remarkable accuracy.
 
26
  from langchain_community.vectorstores import FAISS
27
  from langchain.schema.runnable import RunnablePassthrough
28
  from langchain_core.messages import AIMessage, HumanMessage
29
+ from langchain_community.llms import HuggingFaceEndpoint
30
  from dotenv import load_dotenv
31
 
32
+ # Load environment variables from .env file
33
+ load_dotenv()
34
+
35
  # Get the API token from environment variable
36
  api_token = os.getenv("API_TOKEN")
37
 
38
+ # Define the repository ID and task
39
+ repo_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
40
+ task = "text-generation"
 
41
 
42
+ # Initialize the Hugging Face Endpoint
43
+ chat_model = HuggingFaceEndpoint(
44
+ huggingfacehub_api_token=api_token,
45
+ repo_id=repo_id,
46
+ task=task
 
 
 
 
 
 
 
 
 
 
 
47
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  template = """
50
  You are a genius trader with extensive knowledge of the financial and stock markets, capable of providing deep and insightful analysis of financial stocks with remarkable accuracy.