bingnoi commited on
Commit
948d7b6
1 Parent(s): ecca75f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -4
app.py CHANGED
@@ -1,4 +1,9 @@
1
  # coding=utf-8
 
 
 
 
 
2
  from src.logger import LoggerFactory
3
  from src.prompt_concat import GetManualTestSamples, CreateTestDataset
4
  from src.utils import decode_csv_to_json, load_json, save_to_json
@@ -23,12 +28,18 @@ import spaces
23
  logger = LoggerFactory.create_logger(name="test", level=logging.INFO)
24
  warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
25
 
26
- model_path = os.environ.get('MODEL_PATH', 'IndexTeam/Index-1.9B-Character')
27
- character_path = "./character"
28
- tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
29
- model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto",
 
30
  trust_remote_code=True)
31
 
 
 
 
 
 
32
  # logger = LoggerFactory.create_logger(name="test", level=logging.INFO)
33
  # warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
34
 
 
1
  # coding=utf-8
2
+ from typing import Dict
3
+ from typing import List
4
+ from typing import Tuple
5
+ from typing import Union
6
+ from pathlib import Path
7
  from src.logger import LoggerFactory
8
  from src.prompt_concat import GetManualTestSamples, CreateTestDataset
9
  from src.utils import decode_csv_to_json, load_json, save_to_json
 
28
  logger = LoggerFactory.create_logger(name="test", level=logging.INFO)
29
  warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
30
 
31
+ MODEL_PATH = os.environ.get('MODEL_PATH', 'IndexTeam/Index-1.9B-Character')
32
+ TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
33
+
34
+ tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
35
+ model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float16, device_map="auto",
36
  trust_remote_code=True)
37
 
38
+ character_path = "./character"
39
+
40
+ def _resolve_path(path: Union[str, Path]) -> Path:
41
+ return Path(path).expanduser().resolve()
42
+
43
  # logger = LoggerFactory.create_logger(name="test", level=logging.INFO)
44
  # warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
45