habdine committed on
Commit
35f1339
·
verified ·
1 Parent(s): 077f6a6

Upload 6 files

Browse files
Files changed (6) hide show
  1. README.md +13 -0
  2. app.py +131 -0
  3. gitattributes +35 -0
  4. pre-commit-config.yaml +60 -0
  5. requirements.txt +236 -0
  6. style.css +11 -0
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Atlas Chat 9B
3
+ emoji: 😻
4
+ colorFrom: indigo
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 5.0.1
8
+ app_file: app.py
9
+ pinned: false
10
+ short_description: Chatbot
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from threading import Thread
3
+ from typing import Iterator
4
+
5
+ import gradio as gr
6
+ import spaces
7
+ import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
+
10
# Markdown banner shown at the top of the demo page.
DESCRIPTION = """\
# Atlas Chat 9B


This is a demo of [`MBZUAI-Paris/Atlas-Chat-9B`](https://huggingface.co/MBZUAI-Paris/Atlas-Chat-9B), fine-tuned for instruction following.

"""

# Bounds and default for the number of tokens generated per reply.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
# Prompt budget in tokens; can be overridden through the environment.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

# First CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Load the tokenizer and the bfloat16 weights once, at import time.
model_id = "MBZUAI-Paris/Atlas-Chat-9B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
model.config.sliding_window = 4096
model.eval()  # inference mode
33
+
34
+
35
@spaces.GPU(duration=90)
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` in the context of ``chat_history``.

    Args:
        message: The new user turn.
        chat_history: Earlier turns as ``{"role": ..., "content": ...}`` dicts;
            the list itself is not modified.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff forwarded to ``model.generate``.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.

    Yields:
        The reply accumulated so far, growing by one streamed chunk per step.
    """
    # Build a new list so the caller's history is left untouched.
    conversation = [*chat_history, {"role": "user", "content": message}]

    prompt_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if prompt_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Over budget: keep only the most recent tokens and tell the user.
        prompt_ids = prompt_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    prompt_ids = prompt_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": prompt_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    # Generation runs in a background thread; this generator consumes the
    # streamer and forwards the text as it arrives.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    reply_so_far = ""
    for chunk in streamer:
        reply_so_far += chunk
        yield reply_so_far
73
+
74
+
75
# Extra generation controls, listed in the same order as the keyword
# parameters of generate() that follow (message, chat_history).
_generation_controls = [
    gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
    gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
    gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
    gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
    gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
]

# Sample prompts offered in the UI (one single-element row per example).
_example_prompts = [
    ["شكون لي صنعك؟"],
    ["شنو كيتسمى المنتخب المغربي ؟"],
    ["أشنو كايمييز المملكة المغربية."],
    ["ترجم للدارجة:\nAtlas Chat is the first open source large language model that talks in Darija."],
]

chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=_generation_controls,
    stop_btn=None,
    examples=_example_prompts,
    cache_examples=False,
    type="messages",
)

# Page layout: description banner, duplicate button, then the chat widget.
with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()

if __name__ == "__main__":
    # Queue at most 20 pending requests before launching the server.
    demo.queue(max_size=20).launch()
gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
pre-commit-config.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.6.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: end-of-file-fixer
12
+ - id: mixed-line-ending
13
+ args: ["--fix=lf"]
14
+ - id: requirements-txt-fixer
15
+ - id: trailing-whitespace
16
+ - repo: https://github.com/myint/docformatter
17
+ rev: v1.7.5
18
+ hooks:
19
+ - id: docformatter
20
+ args: ["--in-place"]
21
+ - repo: https://github.com/pycqa/isort
22
+ rev: 5.13.2
23
+ hooks:
24
+ - id: isort
25
+ args: ["--profile", "black"]
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v1.10.1
28
+ hooks:
29
+ - id: mypy
30
+ args: ["--ignore-missing-imports"]
31
+ additional_dependencies:
32
+ [
33
+ "types-python-slugify",
34
+ "types-requests",
35
+ "types-PyYAML",
36
+ "types-pytz",
37
+ ]
38
+ - repo: https://github.com/psf/black
39
+ rev: 24.4.2
40
+ hooks:
41
+ - id: black
42
+ language_version: python3.10
43
+ args: ["--line-length", "119"]
44
+ - repo: https://github.com/kynan/nbstripout
45
+ rev: 0.7.1
46
+ hooks:
47
+ - id: nbstripout
48
+ args:
49
+ [
50
+ "--extra-keys",
51
+ "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
52
+ ]
53
+ - repo: https://github.com/nbQA-dev/nbQA
54
+ rev: 1.8.5
55
+ hooks:
56
+ - id: nbqa-black
57
+ - id: nbqa-pyupgrade
58
+ args: ["--py37-plus"]
59
+ - id: nbqa-isort
60
+ args: ["--float-to-top"]
requirements.txt ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv pip compile pyproject.toml -o requirements.txt
3
+ accelerate==1.0.0
4
+ # via gemma-2-9b-it (pyproject.toml)
5
+ aiofiles==23.2.1
6
+ # via gradio
7
+ annotated-types==0.7.0
8
+ # via pydantic
9
+ anyio==4.6.0
10
+ # via
11
+ # gradio
12
+ # httpx
13
+ # starlette
14
+ certifi==2024.8.30
15
+ # via
16
+ # httpcore
17
+ # httpx
18
+ # requests
19
+ charset-normalizer==3.3.2
20
+ # via requests
21
+ click==8.1.7
22
+ # via
23
+ # typer
24
+ # uvicorn
25
+ exceptiongroup==1.2.2
26
+ # via anyio
27
+ fastapi==0.115.0
28
+ # via gradio
29
+ ffmpy==0.4.0
30
+ # via gradio
31
+ filelock==3.16.1
32
+ # via
33
+ # huggingface-hub
34
+ # torch
35
+ # transformers
36
+ # triton
37
+ fsspec==2024.9.0
38
+ # via
39
+ # gradio-client
40
+ # huggingface-hub
41
+ # torch
42
+ gradio==5.0.1
43
+ # via
44
+ # gemma-2-9b-it (pyproject.toml)
45
+ # spaces
46
+ gradio-client==1.4.0
47
+ # via gradio
48
+ h11==0.14.0
49
+ # via
50
+ # httpcore
51
+ # uvicorn
52
+ hf-transfer==0.1.8
53
+ # via gemma-2-9b-it (pyproject.toml)
54
+ httpcore==1.0.5
55
+ # via httpx
56
+ httpx==0.27.2
57
+ # via
58
+ # gradio
59
+ # gradio-client
60
+ # spaces
61
+ huggingface-hub==0.25.1
62
+ # via
63
+ # accelerate
64
+ # gradio
65
+ # gradio-client
66
+ # tokenizers
67
+ # transformers
68
+ idna==3.10
69
+ # via
70
+ # anyio
71
+ # httpx
72
+ # requests
73
+ jinja2==3.1.4
74
+ # via
75
+ # gradio
76
+ # torch
77
+ markdown-it-py==3.0.0
78
+ # via rich
79
+ markupsafe==2.1.5
80
+ # via
81
+ # gradio
82
+ # jinja2
83
+ mdurl==0.1.2
84
+ # via markdown-it-py
85
+ mpmath==1.3.0
86
+ # via sympy
87
+ networkx==3.3
88
+ # via torch
89
+ numpy==2.1.1
90
+ # via
91
+ # accelerate
92
+ # gradio
93
+ # pandas
94
+ # transformers
95
+ nvidia-cublas-cu12==12.1.3.1
96
+ # via
97
+ # nvidia-cudnn-cu12
98
+ # nvidia-cusolver-cu12
99
+ # torch
100
+ nvidia-cuda-cupti-cu12==12.1.105
101
+ # via torch
102
+ nvidia-cuda-nvrtc-cu12==12.1.105
103
+ # via torch
104
+ nvidia-cuda-runtime-cu12==12.1.105
105
+ # via torch
106
+ nvidia-cudnn-cu12==9.1.0.70
107
+ # via torch
108
+ nvidia-cufft-cu12==11.0.2.54
109
+ # via torch
110
+ nvidia-curand-cu12==10.3.2.106
111
+ # via torch
112
+ nvidia-cusolver-cu12==11.4.5.107
113
+ # via torch
114
+ nvidia-cusparse-cu12==12.1.0.106
115
+ # via
116
+ # nvidia-cusolver-cu12
117
+ # torch
118
+ nvidia-nccl-cu12==2.20.5
119
+ # via torch
120
+ nvidia-nvjitlink-cu12==12.6.68
121
+ # via
122
+ # nvidia-cusolver-cu12
123
+ # nvidia-cusparse-cu12
124
+ nvidia-nvtx-cu12==12.1.105
125
+ # via torch
126
+ orjson==3.10.7
127
+ # via gradio
128
+ packaging==24.1
129
+ # via
130
+ # accelerate
131
+ # gradio
132
+ # gradio-client
133
+ # huggingface-hub
134
+ # spaces
135
+ # transformers
136
+ pandas==2.2.3
137
+ # via gradio
138
+ pillow==10.4.0
139
+ # via gradio
140
+ psutil==5.9.8
141
+ # via
142
+ # accelerate
143
+ # spaces
144
+ pydantic==2.9.2
145
+ # via
146
+ # fastapi
147
+ # gradio
148
+ # spaces
149
+ pydantic-core==2.23.4
150
+ # via pydantic
151
+ pydub==0.25.1
152
+ # via gradio
153
+ pygments==2.18.0
154
+ # via rich
155
+ python-dateutil==2.9.0.post0
156
+ # via pandas
157
+ python-multipart==0.0.12
158
+ # via gradio
159
+ pytz==2024.2
160
+ # via pandas
161
+ pyyaml==6.0.2
162
+ # via
163
+ # accelerate
164
+ # gradio
165
+ # huggingface-hub
166
+ # transformers
167
+ regex==2024.9.11
168
+ # via transformers
169
+ requests==2.32.3
170
+ # via
171
+ # huggingface-hub
172
+ # spaces
173
+ # transformers
174
+ rich==13.8.1
175
+ # via typer
176
+ ruff==0.6.8
177
+ # via gradio
178
+ safetensors==0.4.5
179
+ # via
180
+ # accelerate
181
+ # transformers
182
+ semantic-version==2.10.0
183
+ # via gradio
184
+ shellingham==1.5.4
185
+ # via typer
186
+ six==1.16.0
187
+ # via python-dateutil
188
+ sniffio==1.3.1
189
+ # via
190
+ # anyio
191
+ # httpx
192
+ spaces==0.30.3
193
+ # via gemma-2-9b-it (pyproject.toml)
194
+ starlette==0.38.6
195
+ # via fastapi
196
+ sympy==1.13.3
197
+ # via torch
198
+ tokenizers==0.20.0
199
+ # via transformers
200
+ tomlkit==0.12.0
201
+ # via gradio
202
+ torch==2.4.0
203
+ # via
204
+ # gemma-2-9b-it (pyproject.toml)
205
+ # accelerate
206
+ tqdm==4.66.5
207
+ # via
208
+ # huggingface-hub
209
+ # transformers
210
+ transformers==4.45.2
211
+ # via gemma-2-9b-it (pyproject.toml)
212
+ triton==3.0.0
213
+ # via torch
214
+ typer==0.12.5
215
+ # via gradio
216
+ typing-extensions==4.12.2
217
+ # via
218
+ # anyio
219
+ # fastapi
220
+ # gradio
221
+ # gradio-client
222
+ # huggingface-hub
223
+ # pydantic
224
+ # pydantic-core
225
+ # spaces
226
+ # torch
227
+ # typer
228
+ # uvicorn
229
+ tzdata==2024.2
230
+ # via pandas
231
+ urllib3==2.2.3
232
+ # via requests
233
+ uvicorn==0.31.0
234
+ # via gradio
235
+ websockets==12.0
236
+ # via gradio-client
style.css ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
/* Top-level headings: block-level and centered. */
h1 {
  display: block;
  text-align: center;
}

/* Duplicate-Space button: white text on blue, fully rounded, auto margins. */
#duplicate-button {
  margin: auto;
  border-radius: 100vh;
  background: #1565c0;
  color: #fff;
}