hayas committed on
Commit
aad286b
·
1 Parent(s): d0d7843
.pre-commit-config.yaml CHANGED
@@ -13,18 +13,15 @@ repos:
13
  args: ["--fix=lf"]
14
  - id: requirements-txt-fixer
15
  - id: trailing-whitespace
16
- - repo: https://github.com/myint/docformatter
17
- rev: v1.7.5
18
  hooks:
19
- - id: docformatter
20
- args: ["--in-place"]
21
- - repo: https://github.com/pycqa/isort
22
- rev: 5.13.2
23
- hooks:
24
- - id: isort
25
- args: ["--profile", "black"]
26
  - repo: https://github.com/pre-commit/mirrors-mypy
27
- rev: v1.12.0
28
  hooks:
29
  - id: mypy
30
  args: ["--ignore-missing-imports"]
@@ -35,18 +32,8 @@ repos:
35
  "types-PyYAML",
36
  "types-pytz",
37
  ]
38
- - repo: https://github.com/psf/black
39
- rev: 24.10.0
40
- hooks:
41
- - id: black
42
- language_version: python3.10
43
- args: ["--line-length", "119"]
44
- - repo: https://github.com/charliermarsh/ruff-pre-commit
45
- rev: v0.7.0
46
- hooks:
47
- - id: ruff
48
  - repo: https://github.com/kynan/nbstripout
49
- rev: 0.7.1
50
  hooks:
51
  - id: nbstripout
52
  args:
@@ -55,7 +42,7 @@ repos:
55
  "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
56
  ]
57
  - repo: https://github.com/nbQA-dev/nbQA
58
- rev: 1.8.7
59
  hooks:
60
  - id: nbqa-black
61
  - id: nbqa-pyupgrade
 
13
  args: ["--fix=lf"]
14
  - id: requirements-txt-fixer
15
  - id: trailing-whitespace
16
+ - repo: https://github.com/astral-sh/ruff-pre-commit
17
+ rev: v0.8.4
18
  hooks:
19
+ - id: ruff
20
+ args: ["--fix"]
21
+ - id: ruff-format
22
+ args: ["--line-length", "119"]
 
 
 
23
  - repo: https://github.com/pre-commit/mirrors-mypy
24
+ rev: v1.14.0
25
  hooks:
26
  - id: mypy
27
  args: ["--ignore-missing-imports"]
 
32
  "types-PyYAML",
33
  "types-pytz",
34
  ]
 
 
 
 
 
 
 
 
 
 
35
  - repo: https://github.com/kynan/nbstripout
36
+ rev: 0.8.1
37
  hooks:
38
  - id: nbstripout
39
  args:
 
42
  "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
43
  ]
44
  - repo: https://github.com/nbQA-dev/nbQA
45
+ rev: 1.9.1
46
  hooks:
47
  - id: nbqa-black
48
  - id: nbqa-pyupgrade
.vscode/extensions.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "recommendations": [
3
+ "ms-python.python",
4
+ "charliermarsh.ruff",
5
+ "streetsidesoftware.code-spell-checker",
6
+ "tamasfe.even-better-toml"
7
+ ]
8
+ }
.vscode/settings.json CHANGED
@@ -2,25 +2,20 @@
2
  "editor.formatOnSave": true,
3
  "files.insertFinalNewline": false,
4
  "[python]": {
5
- "editor.defaultFormatter": "ms-python.black-formatter",
6
  "editor.formatOnType": true,
7
  "editor.codeActionsOnSave": {
 
8
  "source.organizeImports": "explicit"
9
  }
10
  },
11
  "[jupyter]": {
12
  "files.insertFinalNewline": false
13
  },
14
- "black-formatter.args": [
15
- "--line-length=119"
16
- ],
17
- "isort.args": ["--profile", "black"],
18
- "flake8.args": [
19
- "--max-line-length=119"
20
- ],
21
- "ruff.lint.args": [
22
- "--line-length=119"
23
- ],
24
  "notebook.output.scrolling": true,
25
- "notebook.formatOnCellExecution": true
 
 
 
 
26
  }
 
2
  "editor.formatOnSave": true,
3
  "files.insertFinalNewline": false,
4
  "[python]": {
5
+ "editor.defaultFormatter": "charliermarsh.ruff",
6
  "editor.formatOnType": true,
7
  "editor.codeActionsOnSave": {
8
+ "source.fixAll.ruff": "explicit",
9
  "source.organizeImports": "explicit"
10
  }
11
  },
12
  "[jupyter]": {
13
  "files.insertFinalNewline": false
14
  },
 
 
 
 
 
 
 
 
 
 
15
  "notebook.output.scrolling": true,
16
+ "notebook.formatOnCellExecution": true,
17
+ "notebook.formatOnSave.enabled": true,
18
+ "notebook.codeActionsOnSave": {
19
+ "source.organizeImports": "explicit"
20
+ }
21
  }
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🐢
4
  colorFrom: purple
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 5.1.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: purple
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 5.9.1
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -1,8 +1,8 @@
1
  #!/usr/bin/env python
2
 
3
  import os
 
4
  from threading import Thread
5
- from typing import Iterator
6
 
7
  import gradio as gr
8
  import spaces
@@ -46,23 +46,24 @@ PROMPT_DICT = {
46
 
47
 
48
  def create_prompt(instruction: str, input_text: str | None = None) -> str:
49
- """Generates a prompt based on the given instruction and an optional input.
 
50
  If input is provided, it uses the 'prompt_input' template from PROMPT_DICT.
51
  If no input is provided, it uses the 'prompt_no_input' template.
52
 
53
  Args:
54
  instruction (str): The instruction describing the task.
55
- input_text (str, optional): Additional input providing context for the task. Default is None.
56
 
57
  Returns:
58
  str: The generated prompt.
 
59
  """
60
  if input_text:
61
  # Use the 'prompt_input' template when additional input is provided
62
  return PROMPT_DICT["prompt_input"].format(instruction=instruction, input=input_text)
63
- else:
64
- # Use the 'prompt_no_input' template when no additional input is provided
65
- return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)
66
 
67
 
68
  @spaces.GPU
@@ -80,7 +81,8 @@ def run(
80
  prompt = create_prompt(instruction, input_text)
81
  input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
82
  if input_ids.shape[-1] > MAX_INPUT_TOKENS:
83
- raise gr.Error(f"Input exceeds maximum number of tokens ({MAX_INPUT_TOKENS})")
 
84
 
85
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
86
  generate_kwargs = dict(
@@ -138,7 +140,7 @@ with gr.Blocks(css_paths="style.css") as demo:
138
  "以下のトピックに関する詳細な情報を提供してください。",
139
  "夢オチとは何かについて教えてください。",
140
  ],
141
- ["暴れん坊将軍って誰のことですか?", ""],
142
  ],
143
  inputs=[instruction, input_text],
144
  outputs=output,
 
1
  #!/usr/bin/env python
2
 
3
  import os
4
+ from collections.abc import Iterator
5
  from threading import Thread
 
6
 
7
  import gradio as gr
8
  import spaces
 
46
 
47
 
48
  def create_prompt(instruction: str, input_text: str | None = None) -> str:
49
+ """Generate a prompt based on the given instruction and an optional input.
50
+
51
  If input is provided, it uses the 'prompt_input' template from PROMPT_DICT.
52
  If no input is provided, it uses the 'prompt_no_input' template.
53
 
54
  Args:
55
  instruction (str): The instruction describing the task.
56
+ input_text (str | None): Additional input providing context for the task. Defaults to None.
57
 
58
  Returns:
59
  str: The generated prompt.
60
+
61
  """
62
  if input_text:
63
  # Use the 'prompt_input' template when additional input is provided
64
  return PROMPT_DICT["prompt_input"].format(instruction=instruction, input=input_text)
65
+ # Use the 'prompt_no_input' template when no additional input is provided
66
+ return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)
 
67
 
68
 
69
  @spaces.GPU
 
81
  prompt = create_prompt(instruction, input_text)
82
  input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
83
  if input_ids.shape[-1] > MAX_INPUT_TOKENS:
84
+ error_message = f"Input exceeds maximum number of tokens ({MAX_INPUT_TOKENS})"
85
+ raise gr.Error(error_message)
86
 
87
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
88
  generate_kwargs = dict(
 
140
  "以下のトピックに関する詳細な情報を提供してください。",
141
  "夢オチとは何かについて教えてください。",
142
  ],
143
+ ["暴れん坊将軍って誰のことですか?", ""], # noqa: RUF001
144
  ],
145
  inputs=[instruction, input_text],
146
  outputs=output,
pyproject.toml CHANGED
@@ -5,16 +5,52 @@ description = ""
5
  readme = "README.md"
6
  requires-python = ">=3.10"
7
  dependencies = [
8
- "accelerate>=1.0.1",
9
- "bitsandbytes>=0.44.1",
10
  "blobfile>=3.0.0",
11
- "gradio>=5.1.0",
12
  "hf-transfer>=0.1.8",
13
- "protobuf>=5.28.2",
14
  "sentencepiece>=0.2.0",
15
- "setuptools>=75.2.0",
16
- "spaces>=0.30.4",
17
  "tiktoken>=0.8.0",
18
  "torch==2.4.0",
19
- "transformers>=4.45.2",
20
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  readme = "README.md"
6
  requires-python = ">=3.10"
7
  dependencies = [
8
+ "accelerate>=1.2.1",
9
+ "bitsandbytes>=0.45.0",
10
  "blobfile>=3.0.0",
11
+ "gradio>=5.9.1",
12
  "hf-transfer>=0.1.8",
13
+ "protobuf>=5.29.2",
14
  "sentencepiece>=0.2.0",
15
+ "setuptools>=75.6.0",
16
+ "spaces>=0.31.1",
17
  "tiktoken>=0.8.0",
18
  "torch==2.4.0",
19
+ "transformers>=4.47.1",
20
  ]
21
+
22
+ [tool.ruff]
23
+ line-length = 119
24
+
25
+ [tool.ruff.lint]
26
+ select = ["ALL"]
27
+ ignore = [
28
+ "COM812", # missing-trailing-comma
29
+ "D203", # one-blank-line-before-class
30
+ "D213", # multi-line-summary-second-line
31
+ "E501", # line-too-long
32
+ "SIM117", # multiple-with-statements
33
+ ]
34
+ extend-ignore = [
35
+ "D100", # undocumented-public-module
36
+ "D101", # undocumented-public-class
37
+ "D102", # undocumented-public-method
38
+ "D103", # undocumented-public-function
39
+ "D104", # undocumented-public-package
40
+ "D105", # undocumented-magic-method
41
+ "D107", # undocumented-public-init
42
+ "EM101", # raw-string-in-exception
43
+ "FBT001", # boolean-type-hint-positional-argument
44
+ "FBT002", # boolean-default-value-positional-argument
45
+ "PD901", # pandas-df-variable-name
46
+ "PGH003", # blanket-type-ignore
47
+ "PLR0913", # too-many-arguments
48
+ "PLR0915", # too-many-statements
49
+ "TRY003", # raise-vanilla-args
50
+ ]
51
+ unfixable = [
52
+ "F401", # unused-import
53
+ ]
54
+
55
+ [tool.ruff.format]
56
+ docstring-code-format = true
requirements.txt CHANGED
@@ -1,36 +1,36 @@
1
  # This file was autogenerated by uv via the following command:
2
  # uv pip compile pyproject.toml -o requirements.txt
3
- accelerate==1.0.1
4
  # via swallow-13b-instruct (pyproject.toml)
5
  aiofiles==23.2.1
6
  # via gradio
7
  annotated-types==0.7.0
8
  # via pydantic
9
- anyio==4.6.2.post1
10
  # via
11
  # gradio
12
  # httpx
13
  # starlette
14
- bitsandbytes==0.44.1
15
  # via swallow-13b-instruct (pyproject.toml)
16
  blobfile==3.0.0
17
  # via swallow-13b-instruct (pyproject.toml)
18
- certifi==2024.8.30
19
  # via
20
  # httpcore
21
  # httpx
22
  # requests
23
- charset-normalizer==3.4.0
24
  # via requests
25
- click==8.1.7
26
  # via
27
  # typer
28
  # uvicorn
29
  exceptiongroup==1.2.2
30
  # via anyio
31
- fastapi==0.115.2
32
  # via gradio
33
- ffmpy==0.4.0
34
  # via gradio
35
  filelock==3.16.1
36
  # via
@@ -39,16 +39,16 @@ filelock==3.16.1
39
  # torch
40
  # transformers
41
  # triton
42
- fsspec==2024.9.0
43
  # via
44
  # gradio-client
45
  # huggingface-hub
46
  # torch
47
- gradio==5.1.0
48
  # via
49
  # swallow-13b-instruct (pyproject.toml)
50
  # spaces
51
- gradio-client==1.4.0
52
  # via gradio
53
  h11==0.14.0
54
  # via
@@ -56,14 +56,15 @@ h11==0.14.0
56
  # uvicorn
57
  hf-transfer==0.1.8
58
  # via swallow-13b-instruct (pyproject.toml)
59
- httpcore==1.0.6
60
  # via httpx
61
- httpx==0.27.2
62
  # via
63
  # gradio
64
  # gradio-client
 
65
  # spaces
66
- huggingface-hub==0.26.0
67
  # via
68
  # accelerate
69
  # gradio
@@ -75,7 +76,7 @@ idna==3.10
75
  # anyio
76
  # httpx
77
  # requests
78
- jinja2==3.1.4
79
  # via
80
  # gradio
81
  # torch
@@ -91,9 +92,9 @@ mdurl==0.1.2
91
  # via markdown-it-py
92
  mpmath==1.3.0
93
  # via sympy
94
- networkx==3.4.1
95
  # via torch
96
- numpy==2.1.2
97
  # via
98
  # accelerate
99
  # bitsandbytes
@@ -125,15 +126,15 @@ nvidia-cusparse-cu12==12.1.0.106
125
  # torch
126
  nvidia-nccl-cu12==2.20.5
127
  # via torch
128
- nvidia-nvjitlink-cu12==12.6.77
129
  # via
130
  # nvidia-cusolver-cu12
131
  # nvidia-cusparse-cu12
132
  nvidia-nvtx-cu12==12.1.105
133
  # via torch
134
- orjson==3.10.9
135
  # via gradio
136
- packaging==24.1
137
  # via
138
  # accelerate
139
  # gradio
@@ -143,9 +144,9 @@ packaging==24.1
143
  # transformers
144
  pandas==2.2.3
145
  # via gradio
146
- pillow==10.4.0
147
  # via gradio
148
- protobuf==5.28.2
149
  # via swallow-13b-instruct (pyproject.toml)
150
  psutil==5.9.8
151
  # via
@@ -153,12 +154,12 @@ psutil==5.9.8
153
  # spaces
154
  pycryptodomex==3.21.0
155
  # via blobfile
156
- pydantic==2.9.2
157
  # via
158
  # fastapi
159
  # gradio
160
  # spaces
161
- pydantic-core==2.23.4
162
  # via pydantic
163
  pydub==0.25.1
164
  # via gradio
@@ -166,7 +167,7 @@ pygments==2.18.0
166
  # via rich
167
  python-dateutil==2.9.0.post0
168
  # via pandas
169
- python-multipart==0.0.12
170
  # via gradio
171
  pytz==2024.2
172
  # via pandas
@@ -176,7 +177,7 @@ pyyaml==6.0.2
176
  # gradio
177
  # huggingface-hub
178
  # transformers
179
- regex==2024.9.11
180
  # via
181
  # tiktoken
182
  # transformers
@@ -186,9 +187,11 @@ requests==2.32.3
186
  # spaces
187
  # tiktoken
188
  # transformers
189
- rich==13.9.2
190
  # via typer
191
- ruff==0.7.0
 
 
192
  # via gradio
193
  safetensors==0.4.5
194
  # via
@@ -198,46 +201,47 @@ semantic-version==2.10.0
198
  # via gradio
199
  sentencepiece==0.2.0
200
  # via swallow-13b-instruct (pyproject.toml)
201
- setuptools==75.2.0
202
  # via swallow-13b-instruct (pyproject.toml)
203
  shellingham==1.5.4
204
  # via typer
205
- six==1.16.0
206
  # via python-dateutil
207
  sniffio==1.3.1
208
- # via
209
- # anyio
210
- # httpx
211
- spaces==0.30.4
212
  # via swallow-13b-instruct (pyproject.toml)
213
- starlette==0.40.0
214
- # via fastapi
 
 
215
  sympy==1.13.3
216
  # via torch
217
  tiktoken==0.8.0
218
  # via swallow-13b-instruct (pyproject.toml)
219
- tokenizers==0.20.1
220
  # via transformers
221
- tomlkit==0.12.0
222
  # via gradio
223
  torch==2.4.0
224
  # via
225
  # swallow-13b-instruct (pyproject.toml)
226
  # accelerate
227
  # bitsandbytes
228
- tqdm==4.66.5
229
  # via
230
  # huggingface-hub
231
  # transformers
232
- transformers==4.45.2
233
  # via swallow-13b-instruct (pyproject.toml)
234
  triton==3.0.0
235
  # via torch
236
- typer==0.12.5
237
  # via gradio
238
  typing-extensions==4.12.2
239
  # via
240
  # anyio
 
241
  # fastapi
242
  # gradio
243
  # gradio-client
@@ -251,11 +255,11 @@ typing-extensions==4.12.2
251
  # uvicorn
252
  tzdata==2024.2
253
  # via pandas
254
- urllib3==2.2.3
255
  # via
256
  # blobfile
257
  # requests
258
- uvicorn==0.32.0
259
  # via gradio
260
- websockets==12.0
261
  # via gradio-client
 
1
  # This file was autogenerated by uv via the following command:
2
  # uv pip compile pyproject.toml -o requirements.txt
3
+ accelerate==1.2.1
4
  # via swallow-13b-instruct (pyproject.toml)
5
  aiofiles==23.2.1
6
  # via gradio
7
  annotated-types==0.7.0
8
  # via pydantic
9
+ anyio==4.7.0
10
  # via
11
  # gradio
12
  # httpx
13
  # starlette
14
+ bitsandbytes==0.45.0
15
  # via swallow-13b-instruct (pyproject.toml)
16
  blobfile==3.0.0
17
  # via swallow-13b-instruct (pyproject.toml)
18
+ certifi==2024.12.14
19
  # via
20
  # httpcore
21
  # httpx
22
  # requests
23
+ charset-normalizer==3.4.1
24
  # via requests
25
+ click==8.1.8
26
  # via
27
  # typer
28
  # uvicorn
29
  exceptiongroup==1.2.2
30
  # via anyio
31
+ fastapi==0.115.6
32
  # via gradio
33
+ ffmpy==0.5.0
34
  # via gradio
35
  filelock==3.16.1
36
  # via
 
39
  # torch
40
  # transformers
41
  # triton
42
+ fsspec==2024.12.0
43
  # via
44
  # gradio-client
45
  # huggingface-hub
46
  # torch
47
+ gradio==5.9.1
48
  # via
49
  # swallow-13b-instruct (pyproject.toml)
50
  # spaces
51
+ gradio-client==1.5.2
52
  # via gradio
53
  h11==0.14.0
54
  # via
 
56
  # uvicorn
57
  hf-transfer==0.1.8
58
  # via swallow-13b-instruct (pyproject.toml)
59
+ httpcore==1.0.7
60
  # via httpx
61
+ httpx==0.28.1
62
  # via
63
  # gradio
64
  # gradio-client
65
+ # safehttpx
66
  # spaces
67
+ huggingface-hub==0.27.0
68
  # via
69
  # accelerate
70
  # gradio
 
76
  # anyio
77
  # httpx
78
  # requests
79
+ jinja2==3.1.5
80
  # via
81
  # gradio
82
  # torch
 
92
  # via markdown-it-py
93
  mpmath==1.3.0
94
  # via sympy
95
+ networkx==3.4.2
96
  # via torch
97
+ numpy==2.2.1
98
  # via
99
  # accelerate
100
  # bitsandbytes
 
126
  # torch
127
  nvidia-nccl-cu12==2.20.5
128
  # via torch
129
+ nvidia-nvjitlink-cu12==12.6.85
130
  # via
131
  # nvidia-cusolver-cu12
132
  # nvidia-cusparse-cu12
133
  nvidia-nvtx-cu12==12.1.105
134
  # via torch
135
+ orjson==3.10.13
136
  # via gradio
137
+ packaging==24.2
138
  # via
139
  # accelerate
140
  # gradio
 
144
  # transformers
145
  pandas==2.2.3
146
  # via gradio
147
+ pillow==11.1.0
148
  # via gradio
149
+ protobuf==5.29.2
150
  # via swallow-13b-instruct (pyproject.toml)
151
  psutil==5.9.8
152
  # via
 
154
  # spaces
155
  pycryptodomex==3.21.0
156
  # via blobfile
157
+ pydantic==2.10.4
158
  # via
159
  # fastapi
160
  # gradio
161
  # spaces
162
+ pydantic-core==2.27.2
163
  # via pydantic
164
  pydub==0.25.1
165
  # via gradio
 
167
  # via rich
168
  python-dateutil==2.9.0.post0
169
  # via pandas
170
+ python-multipart==0.0.20
171
  # via gradio
172
  pytz==2024.2
173
  # via pandas
 
177
  # gradio
178
  # huggingface-hub
179
  # transformers
180
+ regex==2024.11.6
181
  # via
182
  # tiktoken
183
  # transformers
 
187
  # spaces
188
  # tiktoken
189
  # transformers
190
+ rich==13.9.4
191
  # via typer
192
+ ruff==0.8.4
193
+ # via gradio
194
+ safehttpx==0.1.6
195
  # via gradio
196
  safetensors==0.4.5
197
  # via
 
201
  # via gradio
202
  sentencepiece==0.2.0
203
  # via swallow-13b-instruct (pyproject.toml)
204
+ setuptools==75.6.0
205
  # via swallow-13b-instruct (pyproject.toml)
206
  shellingham==1.5.4
207
  # via typer
208
+ six==1.17.0
209
  # via python-dateutil
210
  sniffio==1.3.1
211
+ # via anyio
212
+ spaces==0.31.1
 
 
213
  # via swallow-13b-instruct (pyproject.toml)
214
+ starlette==0.41.3
215
+ # via
216
+ # fastapi
217
+ # gradio
218
  sympy==1.13.3
219
  # via torch
220
  tiktoken==0.8.0
221
  # via swallow-13b-instruct (pyproject.toml)
222
+ tokenizers==0.21.0
223
  # via transformers
224
+ tomlkit==0.13.2
225
  # via gradio
226
  torch==2.4.0
227
  # via
228
  # swallow-13b-instruct (pyproject.toml)
229
  # accelerate
230
  # bitsandbytes
231
+ tqdm==4.67.1
232
  # via
233
  # huggingface-hub
234
  # transformers
235
+ transformers==4.47.1
236
  # via swallow-13b-instruct (pyproject.toml)
237
  triton==3.0.0
238
  # via torch
239
+ typer==0.15.1
240
  # via gradio
241
  typing-extensions==4.12.2
242
  # via
243
  # anyio
244
+ # bitsandbytes
245
  # fastapi
246
  # gradio
247
  # gradio-client
 
255
  # uvicorn
256
  tzdata==2024.2
257
  # via pandas
258
+ urllib3==2.3.0
259
  # via
260
  # blobfile
261
  # requests
262
+ uvicorn==0.34.0
263
  # via gradio
264
+ websockets==14.1
265
  # via gradio-client
uv.lock CHANGED
The diff for this file is too large to render. See raw diff