shenchucheng committed
Commit d0921c0
1 parent: faf9dc6

add local storage

Files changed (5)
  1. .dockerignore +2 -1
  2. .gitignore +1 -0
  3. app.py +62 -54
  4. config/config.yaml +14 -0
  5. software_company.py +26 -17
.dockerignore CHANGED
@@ -5,4 +5,5 @@ workspace
 dist
 data
 geckodriver.log
-logs
+logs
+storage

.gitignore CHANGED
@@ -169,3 +169,4 @@ output.wav
 output
 tmp.png
 
+storage/*

app.py CHANGED
@@ -6,6 +6,7 @@ import asyncio
 from collections import deque
 import contextlib
 from functools import partial
+import shutil
 import urllib.parse
 from datetime import datetime
 import uuid
@@ -154,58 +155,60 @@ async def create_message(req_model: NewMsg, request: Request):
     """
     Session message stream
     """
-    config = {k.upper(): v for k, v in req_model.config.items()}
-    set_context(config, uuid.uuid4().hex)
-
-    msg_queue = deque()
-    CONFIG.LLM_STREAM_LOG = lambda x: msg_queue.appendleft(x) if x else None
-
-    role = SoftwareCompany()
-    role.recv(message=Message(content=req_model.query))
-    answer = MessageJsonModel(
-        steps=[
-            Sentences(
-                contents=[
-                    Sentence(type=SentenceType.TEXT.value, value=SentenceValue(answer=req_model.query), is_finished=True)
-                ],
-                status=MessageStatus.COMPLETE.value,
-            )
-        ],
-        qa_type=QueryAnswerType.Answer.value,
-    )
-
-    tc_id = 0
-
-    while True:
-        tc_id += 1
-        if request and await request.is_disconnected():
-            return
-        think_result: RoleRun = await role.think()
-        if not think_result:  # End of conversion
-            break
-
-        think_act_prompt = ThinkActPrompt(role=think_result.role.profile)
-        think_act_prompt.update_think(tc_id, think_result)
-        yield think_act_prompt.prompt + "\n\n"
-        task = asyncio.create_task(role.act())
-
-        while not await request.is_disconnected():
-            if msg_queue:
-                think_act_prompt.update_act(msg_queue.pop(), False)
-                yield think_act_prompt.prompt + "\n\n"
-                continue
-
-            if task.done():
-                break
-
-            await asyncio.sleep(0.5)
-
-        act_result = await task
-        think_act_prompt.update_act(act_result)
-        yield think_act_prompt.prompt + "\n\n"
-        answer.add_think_act(think_act_prompt)
-    yield answer.prompt + "\n\n"  # Notify the front-end that the message is complete.
-
+    try:
+        config = {k.upper(): v for k, v in req_model.config.items()}
+        set_context(config, uuid.uuid4().hex)
+
+        msg_queue = deque()
+        CONFIG.LLM_STREAM_LOG = lambda x: msg_queue.appendleft(x) if x else None
+
+        role = SoftwareCompany()
+        role.recv(message=Message(content=req_model.query))
+        answer = MessageJsonModel(
+            steps=[
+                Sentences(
+                    contents=[
+                        Sentence(type=SentenceType.TEXT.value, value=SentenceValue(answer=req_model.query), is_finished=True)
+                    ],
+                    status=MessageStatus.COMPLETE.value,
+                )
+            ],
+            qa_type=QueryAnswerType.Answer.value,
+        )
+
+        tc_id = 0
+
+        while True:
+            tc_id += 1
+            if request and await request.is_disconnected():
+                return
+            think_result: RoleRun = await role.think()
+            if not think_result:  # End of conversion
+                break
+
+            think_act_prompt = ThinkActPrompt(role=think_result.role.profile)
+            think_act_prompt.update_think(tc_id, think_result)
+            yield think_act_prompt.prompt + "\n\n"
+            task = asyncio.create_task(role.act())
+
+            while not await request.is_disconnected():
+                if msg_queue:
+                    think_act_prompt.update_act(msg_queue.pop(), False)
+                    yield think_act_prompt.prompt + "\n\n"
+                    continue
+
+                if task.done():
+                    break
+
+                await asyncio.sleep(0.5)
+
+            act_result = await task
+            think_act_prompt.update_act(act_result)
+            yield think_act_prompt.prompt + "\n\n"
+            answer.add_think_act(think_act_prompt)
+        yield answer.prompt + "\n\n"  # Notify the front-end that the message is complete.
+    finally:
+        shutil.rmtree(CONFIG.WORKSPACE_PATH)
 
 default_llm_stream_log = partial(print, end="")
 
@@ -236,10 +239,11 @@ class ChatHandler:
 app = FastAPI()
 
 app.mount(
-    "/static",
-    StaticFiles(directory="./static/", check_dir=True),
+    "/storage",
+    StaticFiles(directory="./storage/"),
     name="static",
 )
+
 app.add_api_route(
     "/api/messages",
     endpoint=ChatHandler.create_message,
@@ -250,13 +254,17 @@ app.add_api_route(
 
 @app.get("/{catch_all:path}")
 async def catch_all(request: Request):
-    if request.url.path == "/":
-        return RedirectResponse(url="/static/index.html")
     if request.url.path.startswith("/api"):
         raise HTTPException(status_code=404)
 
-    new_path = f"/static{request.url.path}"
-    return RedirectResponse(url=new_path)
+    return RedirectResponse(url="/index.html")
+
+
+app.mount(
+    "/",
+    StaticFiles(directory="./static/", html=True),
+    name="static",
+)
 
 
 set_llm_stream_logfunc(llm_stream_log)

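
Note: the new /storage mount pairs with the local-storage branch added to software_company.py below: anything written under ./storage/<key> becomes downloadable at /storage/<key>. A minimal standalone sketch of that pairing (the demo key and file contents are made up and not part of this commit):

from pathlib import Path

from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles

app = FastAPI()

# Create the directory before mounting; StaticFiles verifies it exists at construction time.
Path("storage").mkdir(exist_ok=True)

# Mirror the mount added in app.py: files under ./storage/ are served at /storage/.
app.mount("/storage", StaticFiles(directory="./storage/"), name="storage")

# Write an artifact the way the local branch of get_download_url does.
key = "demo/output.zip"  # hypothetical key
filepath = Path("storage") / key
filepath.parent.mkdir(exist_ok=True, parents=True)
filepath.write_bytes(b"...")  # placeholder bytes

# The helper would return "storage/demo/output.zip", which the browser resolves
# against the site root to /storage/demo/output.zip, i.e. this mount.

Assuming the sketch is saved as sketch.py, running it with uvicorn sketch:app and fetching /storage/demo/output.zip shows the served artifact.
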
config/config.yaml CHANGED
@@ -120,3 +120,17 @@ RPM: 10
 # PROMPT_FORMAT: json #json or markdown
 
 DISABLE_LLM_PROVIDER_CHECK: true
+STORAGE_TYPE: local # local / s3
+
+# for local storage
+LOCAL_ROOT: "storage"
+LOCAL_BASE_URL: "storage"
+
+
+# for s3 storage
+
+# S3_ACCESS_KEY: ""
+# S3_SECRET_KEY: ""
+# S3_ENDPOINT_URL: ""
+# S3_BUCKET: ""
+# S3_SECURE: false

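
These keys are consumed through the project's CONFIG object (see software_company.py below). A rough standalone sketch of how they resolve, assuming PyYAML is available and using the defaults from this file:

import yaml  # PyYAML, assumed available

with open("config/config.yaml") as f:
    cfg = yaml.safe_load(f) or {}

# get_download_url below compares this value against the string "S3";
# any other value, including the default "local", means on-disk storage.
storage_type = cfg.get("STORAGE_TYPE", "local")
local_root = cfg.get("LOCAL_ROOT", "storage")          # directory uploads are written to
local_base_url = cfg.get("LOCAL_BASE_URL", "storage")  # prefix of the returned download URL

print(storage_type, local_root, local_base_url)
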
software_company.py CHANGED
@@ -62,7 +62,7 @@ class PackProject(Action):
         chunks = []
         async for chunk in AioZipStream(files, chunksize=32768).stream():
             chunks.append(chunk)
-        return await upload_to_s3(b"".join(chunks), key)
+        return await get_download_url(b"".join(chunks), key)
 
 
 class SoftwareCompany(Role):
@@ -345,22 +345,31 @@ class SoftwareCompany(Role):
 async def upload_file_to_s3(filepath: str, key: str):
     async with aiofiles.open(filepath, "rb") as f:
         content = await f.read()
-    return await upload_to_s3(content, key)
-
-
-async def upload_to_s3(content: bytes, key: str):
-    session = get_session()
-    async with session.create_client(
-        "s3",
-        aws_secret_access_key=CONFIG.get("S3_SECRET_KEY"),
-        aws_access_key_id=CONFIG.get("S3_ACCESS_KEY"),
-        endpoint_url=CONFIG.get("S3_ENDPOINT_URL"),
-        use_ssl=CONFIG.get("S3_SECURE"),
-    ) as client:
-        # upload object to amazon s3
-        bucket = CONFIG.get("S3_BUCKET")
-        await client.put_object(Bucket=bucket, Key=key, Body=content)
-        return f"{CONFIG.get('S3_ENDPOINT_URL')}/{bucket}/{key}"
+    return await get_download_url(content, key)
+
+
+async def get_download_url(content: bytes, key: str) -> str:
+    if CONFIG.get("STORAGE_TYPE") == "S3":
+        session = get_session()
+        async with session.create_client(
+            "s3",
+            aws_secret_access_key=CONFIG.get("S3_SECRET_KEY"),
+            aws_access_key_id=CONFIG.get("S3_ACCESS_KEY"),
+            endpoint_url=CONFIG.get("S3_ENDPOINT_URL"),
+            use_ssl=CONFIG.get("S3_SECURE"),
+        ) as client:
+            # upload object to amazon s3
+            bucket = CONFIG.get("S3_BUCKET")
+            await client.put_object(Bucket=bucket, Key=key, Body=content)
+            return f"{CONFIG.get('S3_ENDPOINT_URL')}/{bucket}/{key}"
+    else:
+        storage = CONFIG.get("LOCAL_ROOT", "storage")
+        base_url = CONFIG.get("LOCAL_BASE_URL", "storage")
+        filepath = Path(storage) / key
+        filepath.parent.mkdir(exist_ok=True, parents=True)
+        async with aiofiles.open(filepath, "wb") as f:
+            await f.write(content)
+        return f"{base_url}/{key}"
 
 
 async def main(idea, **kwargs):
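
A minimal usage sketch of the new helper in its local branch, assuming software_company.py at the repo root is importable with the project's config loaded as usual, and that STORAGE_TYPE is anything other than "S3" (the branch condition compares against that exact string); the key name here is made up:

import asyncio

from software_company import get_download_url  # helper added in this commit


async def demo() -> None:
    # Writes the bytes to <LOCAL_ROOT>/demo/report.md (storage/demo/report.md by default)
    # and returns "<LOCAL_BASE_URL>/demo/report.md", e.g. "storage/demo/report.md".
    url = await get_download_url(b"# hello\n", "demo/report.md")
    print(url)


if __name__ == "__main__":
    asyncio.run(demo())

That relative URL is what the front end receives, and the /storage mount added to app.py is what makes it fetchable.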