|
import glob |
|
import io |
|
import os |
|
import re |
|
import shutil |
|
import sys |
|
from contextlib import closing |
|
from pathlib import Path |
|
|
|
import boto3 |
|
import gradio as gr |
|
import requests |
|
from langchain.agents import Tool, initialize_agent, AgentType |
|
from langchain.chains import LLMChain, LLMMathChain, StuffDocumentsChain, ConversationalRetrievalChain |
|
from langchain.embeddings.openai import OpenAIEmbeddings |
|
from langchain.memory import ChatMessageHistory, ConversationBufferMemory |
|
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter |
|
from langchain_community.chat_models import AzureChatOpenAI, ChatOllama |
|
from langchain_community.document_loaders import DirectoryLoader, UnstructuredFileLoader, YoutubeLoader |
|
from langchain_community.llms import AzureOpenAI, Ollama |
|
from langchain_community.vectorstores import Chroma |
|
from langchain_core.prompts import PromptTemplate |
|
from sqlitedict import SqliteDict |
|
|
|
from azure_utils import AzureVoiceData |
|
from polly_utils import PollyVoiceData, NEURAL_ENGINE |
|
|
|
|
|
|
|
|
|
|
|
# Azure OpenAI deployment id and model name; passed as deployment_name /
# model_name when the Azure completion and chat models are constructed.
global_deployment_id = "gpt-4-32k"
global_model_name = "gpt-4-32k"

# Base URL handed to the Ollama / ChatOllama clients.
ollama_url = "http://10.51.50.39:3000"
# Model tags presumably available on that Ollama server — TODO confirm.
ollama_models = ["qwen:72b","yi:34b-chat","deepseek-coder:33b"]


# Not referenced elsewhere in the visible code.
chroma_api_impl = "HH_Azure_Openai"

# Data directory; the sub-folder names below are appended to it with "+".
root_file_path = "./data/"
hr_source_path = "hr_source"
ks_source_path = "ks_source"
# Not referenced elsewhere in the visible code.
believe_source_path = 'be_source'

# SqliteDict file and the key under which uploaded HR file names are kept.
sqlite_name = "cache.sqlite3"
sqlite_key="stored_files"
# Chroma persistence folder (relative to root_file_path) and collection names.
persist_db = "persist_db"
hr_collection_name = "hr_db"
# Presumably a Chroma backend identifier — TODO confirm it is a valid value.
chroma_db_impl="localdb+langchain"
tmp_collection="tmp_collection"


# Prompt shown by input() in the console chat loop.
inputText = "問題(按q 或Ctrl + c跳出): "
# User-facing refusal message.
refuse_string="服務被拒. 內容可能涉及敏感字詞,政治,煽動他人或是其他不當言詞, 請改以其他內容嚐試"


LOOPING_TALKING_HEAD = "./data/videos/Masahiro.mp4"
# Used as both width and height of the generated <video> element.
TALKING_HEAD_WIDTH = "192"
# NOTE: both voice-data helpers are instantiated at import time.
AZURE_VOICE_DATA = AzureVoiceData()
POLLY_VOICE_DATA = PollyVoiceData()

# User-supplied summary prompt; an empty string means "use the built-in
# default" (see get_prompt_summary_string).
prompt_string =""
|
|
|
def save_sqlite(key,value):
    """Prepend ``value`` to the list stored under ``key`` in the SQLite cache.

    Args:
        key: Cache key to update.
        value: List of new entries; it is concatenated in front of whatever
            is already stored under ``key``.

    Failures are reported on stdout and otherwise swallowed (best effort).
    """
    try:
        with SqliteDict(sqlite_name) as mydict:
            # The very first save for a key has nothing stored yet.  Indexing
            # would raise KeyError, which the handler below would swallow, so
            # the value would never be written.  Start from an empty list.
            old_value = mydict.get(key, [])
            mydict[key] = value + old_value
            mydict.commit()
    except Exception as ex:
        print("Error during storing data (Possibly unsupported):", ex)
|
|
|
def load_sqlite(key):
    """Return the value stored under ``key`` in the SQLite cache.

    Returns:
        The stored value, or ``None`` when the key has not been stored yet or
        the cache could not be read (read failures are printed to stdout).
    """
    try:
        with SqliteDict(sqlite_name) as mydict:
            # .get() keeps a not-yet-created key from being reported as a
            # load error; get_hr_files already treats None as "nothing stored".
            return mydict.get(key)
    except Exception as ex:
        print("Error during loading data:", ex)
        return None
|
|
|
def delete_sql(key):
    """Reset the cache entry for ``key`` to an empty list.

    The key itself is kept; only its value is replaced.  Any failure is
    printed to stdout and otherwise ignored.
    """
    try:
        with SqliteDict(sqlite_name) as store:
            store[key] = []
            store.commit()
    except Exception as err:
        print("Error during storing data (Possibly unsupported):", err)
|
|
|
def ai_answer(answer):
    """Print ``answer`` to stdout in green, prefixed with the AI label."""
    green, reset = '\033[32m', '\033[0m'
    print('AI 回答: ' + green + answer + reset)
|
|
|
def get_llm_model(model_type, model_name):
    """Build the completion-style LLM client for the requested backend.

    Args:
        model_type: ``"azure"`` or ``"ollama"``.
        model_name: Ollama model tag; ignored for ``"azure"``.

    Returns:
        An ``AzureOpenAI`` or ``Ollama`` instance.

    Raises:
        gr.Error: ``model_name`` is not one of ``ollama_models``.
        ValueError: ``model_type`` is not a known backend (previously this
            fell through and returned ``None``).
    """
    if model_type == "azure":
        return AzureOpenAI(deployment_name=global_deployment_id, model_name=global_model_name)

    if model_type == "ollama":
        # Validate against the shared list instead of repeating each model
        # tag in a separate branch.
        if model_name not in ollama_models:
            raise gr.Error("the current model is not supported in your Ollama server!")
        return Ollama(model=model_name, base_url=ollama_url)

    raise ValueError(f"unsupported model_type: {model_type!r}")
|
|
|
def get_chat_model(model_type, model_name):
    """Build the chat-style LLM client for the requested backend.

    Args:
        model_type: ``"azure"`` or ``"ollama"``.
        model_name: Ollama model tag; ignored for ``"azure"``.

    Returns:
        An ``AzureChatOpenAI`` or ``ChatOllama`` instance.

    Raises:
        gr.Error: ``model_name`` is not one of ``ollama_models``.
        ValueError: ``model_type`` is not a known backend (previously this
            fell through and returned ``None``).
    """
    if model_type == "azure":
        return AzureChatOpenAI(deployment_name=global_deployment_id,
                               model_name=global_model_name)

    if model_type == "ollama":
        # Validate against the shared list instead of repeating each model
        # tag in a separate branch.
        if model_name not in ollama_models:
            raise gr.Error("the current model is not supported in your Ollama server!")
        return ChatOllama(model=model_name, base_url=ollama_url)

    raise ValueError(f"unsupported model_type: {model_type!r}")
|
|
|
def get_openaiembeddings():
    """Return the Azure OpenAI embeddings client used for all Chroma access.

    The API key is taken from the ``CIVETGPT_EMBEDDING_API_KEY`` environment
    variable when it is set; otherwise the historical built-in key is used so
    existing deployments keep working.
    """
    # SECURITY: a secret is committed in source.  Rotate it, provide it through
    # the environment variable, and remove this fallback once deployments are
    # updated.
    api_key = os.environ.get(
        "CIVETGPT_EMBEDDING_API_KEY",
        "0e3e5b666818488fa1b5cb4e4238ffa7",
    )
    return OpenAIEmbeddings(
        deployment="CivetGPT_embedding",
        model="text-embedding-ada-002",
        openai_api_base="https://civet-project-001.openai.azure.com/",
        openai_api_type="azure",
        openai_api_key=api_key,
        chunk_size=1
    )
|
|
|
def multidocs_loader(files_path, file_ext):
    """Load every ``*.<file_ext>`` file under ``files_path`` and split it.

    Returns the documents produced by a ``CharacterTextSplitter`` configured
    with ``chunk_size=1000`` and ``chunk_overlap=10``.
    """
    raw_docs = DirectoryLoader(
        files_path,
        glob="*." + file_ext,
        show_progress=True,
    ).load()
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=10)
    return splitter.split_documents(raw_docs)
|
|
|
def unstructure_file_loader(filename_path):
    """Load a single file with ``UnstructuredFileLoader`` and split it.

    Returns the documents produced by a ``CharacterTextSplitter`` configured
    with ``chunk_size=1000`` and ``chunk_overlap=10``.
    """
    raw_docs = UnstructuredFileLoader(filename_path).load()
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=10)
    return splitter.split_documents(raw_docs)
|
|
|
def add_documents_into_cromadb(db_name, file_path, collection_name):
    """Embed every file under ``file_path`` into a persisted Chroma collection.

    Args:
        db_name: Accepted for interface compatibility; not used.
        file_path: Directory whose files (glob ``*.*``) are loaded and split.
        collection_name: Target Chroma collection.
    """
    split_docs = multidocs_loader(file_path, "*")
    embedder = get_openaiembeddings()

    store = Chroma.from_documents(
        split_docs,
        embedder,
        collection_name=collection_name,
        persist_directory=root_file_path + persist_db,
    )
    store.persist()
    print('adding documents done!')
|
|
|
def initial_croma_db(db_name, files_path, file_ext, collection_name):
    """Create and persist a Chroma collection from the files in a directory.

    Args:
        db_name: Accepted for interface compatibility; not used.
        files_path: Directory to load documents from.
        file_ext: File extension (without the dot) selecting which files load.
        collection_name: Target Chroma collection.
    """
    documents = multidocs_loader(files_path, file_ext)
    embeddings = get_openaiembeddings()

    # The former ``chroma_db_impl=`` keyword is intentionally not passed.
    # from_documents forwards unknown keywords to the Chroma constructor,
    # which does not accept it.  This matches add_documents_into_cromadb.
    chroma_db = Chroma.from_documents(
        documents,
        embeddings,
        collection_name=collection_name,
        persist_directory=root_file_path + persist_db,
    )

    chroma_db.persist()
    print('vectorstore done!')
|
|
|
def add_files_to_collection(input_file_path, collection_name):
    """Index the folder ``root_file_path + input_file_path`` into a collection."""
    add_documents_into_cromadb(
        persist_db,
        root_file_path + input_file_path,
        collection_name,
    )
|
|
|
def get_prompt_summary_string():
    """Return the summary prompt template.

    The module-level ``prompt_string`` wins when it is non-empty; otherwise a
    built-in Chinese summary prompt with a ``{text}`` placeholder is returned.
    """
    if prompt_string != "":
        return prompt_string

    return """使用中文替下面內容做個精簡摘要:

{text}

精簡摘要:"""
|
|
|
# Condense-question prompt for the HR knowledge-base chain; expects the
# {chat_history} and {question} placeholders.
template_string = """
我是鴻海(等同Foxconn)的員工, 你是一個鴻海的人資專家.
請根據歷史對話,針對這次的問題, 形成獨立問題. 請優先從提供的文件中尋找答案, 你被允許回答不知道, 但回答不知道時需要給中央人資的客服聯絡窗口資訊.
不論什麼問題, 都以中文回答

歷史對話: {chat_history}
這次的問題: {question}
人資專家:
"""

# Default prompt for summarising a non-disclosure agreement; the document
# text is substituted for {text}.
default_legal_contract_prompt = """
你是一位超級助理, 十分擅長從大量文字中擷取摘要.
以下用 ''' 包含的是保密合約的內容,幫我生成一份2,000個中文字以內保密合約摘要,摘要需要包含以下項目:
1.背景: 介紹對方公司的背景、為什麼要跟該公司簽訂保密合約
2.目的: 要與對方交換什麼資料, 資料內容與範圍
3.合約期間:保密合約的時間範圍
4.提前解約條款: 發生什麼樣的條件就會要提前解約
5.保密期間: 保密的時間範圍
6.管轄法院: 如有爭端,雙方同意的管轄法院是哪個法院

AI 風險評估: 希望AI 可以評估該資料交換是否有高風險的疑慮; 評估準測:
高風險: 涉及到營業秘密的內容
中風險: 沒有營業秘密, 但有涉及敏感資料(足以辨識個人的訊息)
低風險: 僅涉及作業面向的訊息

保密合約:
'''
{text}
'''

"""

# Default prompt for summarising a quotation; the document text is
# substituted for {text}.
default_legal_quotation_prompt = """
你是一位超級助理, 十分擅長從大量文字中擷取摘要.
以下用 ''' 包含的是報價單的內容,幫我生成一份2,000個中文字以內報價單摘要,摘要需要包含以下項目:

1. 標的名稱: 報價單中所列出的產品或服務的名稱。
2. 價格: 報價單中所列出的每個產品或服務的價格, 一定要有正確的幣別與金額數字.
3. 付款內容: 報價單中所列出的付款方式和相關內容, 包括訂金, 交貨款和保留款的金額和支付方式; 除了各款項的交付百分比, 也需要有正確的金額與幣別.
4. 交貨時間: 報價單中所列出的產品或服務的交付的日期或時間範圍。
5. 保固(英文為Warranty): 請摘要報價單中所有關於保固內容.
6. 維修費用:報價單中所列出的產品或服務的維修費用或相關條款, 有任何維修的金額請一定要列出.
7. 貿易條件(Trade Term)
8. 其他注意事項:報價單中所列出的其他重要事項或注意事項。

請根據報價單的內容, 生成一份清晰明確的摘要, 條列式地把摘要列出, 確保所有項目都被包含在內. 如果內容超過三句話, 請以子項目的方式逐一列舉出來.

請注意,生成的摘要應該是簡潔且易於理解的, 要詳細條列出內容, 不可產生 "依其他文件說明" 等說明方式.
在報價單裡沒有找到符合的資訊, 你被允許回答 "無相關資料".

報價單內容:

'''
{text}
'''
"""
|
|
|
def get_prompt_template_string():
    """Log and return the module-level HR condense-question template."""
    current_template = template_string
    print("template:"+current_template)
    return current_template
|
|
|
def get_default_template_prompt():
    """Return a ``PromptTemplate`` asking for a Chinese explanation of ``{concept}``."""
    return PromptTemplate(
        input_variables=["concept"],
        template="你是個知識廣泛的超級助手, 以下所有問題請用中文回答, 並請在500個中文字以內來解釋 {concept} 概念",
    )
|
|
|
def chat_conversation():
    """Run an interactive console chat against the Azure chat model.

    Reads questions from stdin until the user enters ``q`` or presses
    Ctrl+C / sends EOF.  Every turn is appended to one running message
    history, which is sent to the model on each request.
    """
    print("resource: " + global_deployment_id + " / " + global_model_name)
    chat = AzureChatOpenAI(
        deployment_name=global_deployment_id,
        model_name=global_model_name,
    )

    history = ChatMessageHistory()
    history.add_ai_message("你是一個超級助理, 以下問題都用中文回答")
    while True:
        try:
            text = input(inputText)
        except (KeyboardInterrupt, EOFError):
            # The prompt advertises Ctrl+C as an exit; leave without a traceback.
            break
        if text == 'q':
            break
        history.add_user_message(text)
        ai_response = chat(history.messages)
        # Record the reply as well, otherwise the model never sees its own
        # previous answers on follow-up questions.
        history.add_ai_message(ai_response.content)
        ai_answer(ai_response.content)
|
|
|
def local_vector_search(question_str,
                        chat_history,
                        collection_name = hr_collection_name):
    """Answer ``question_str`` from a persisted Chroma collection.

    Args:
        question_str: The user's question.
        chat_history: Accepted for interface compatibility.  The chain reads
            its history from a fresh ``ConversationBufferMemory`` created
            here, so this value does not reach the model.
        collection_name: Chroma collection to search.

    Returns:
        The ``"answer"`` entry of the retrieval chain result.
    """
    embedding = get_openaiembeddings()
    vectorstore = Chroma(
        embedding_function=embedding,
        collection_name=collection_name,
        persist_directory=root_file_path + persist_db,
    )

    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, ai_prefix="AI超級助理")

    # get_chat_model takes (model_type, model_name); the previous
    # ``chat_model_type=`` keyword raised TypeError before any search ran.
    chat_llm = get_chat_model(model_type="azure", model_name="")

    prompt = PromptTemplate(
        template=get_prompt_template_string(),
        input_variables=["question", "chat_history"],
    )
    km_chain = ConversationalRetrievalChain.from_llm(
        llm=chat_llm,
        retriever=vectorstore.as_retriever(),
        memory=memory,
        condense_question_prompt=prompt,
    )

    # The Calculator / Knowledge Base tools and the agent that used to be
    # built here were never invoked; the answer always came from km_chain, so
    # that dead construction (and the discarded prompt.format call) is removed.
    result = km_chain(question_str)

    print(result)
    return result["answer"]
|
|
|
def make_markdown_table(array):
    """Render each entry of ``array`` on its own line.

    Every entry is followed by a single space and a newline; an empty input
    yields an empty string.
    """
    return "".join(f"{item} \n" for item in array)
|
|
|
def get_hr_files():
    """Return the stored HR file names as newline-separated text.

    Returns:
        The rendered list, or ``None`` when nothing could be loaded from the
        cache.
    """
    files = load_sqlite(sqlite_key)
    # Identity check: None is the "nothing stored" signal from load_sqlite.
    if files is None:
        return None
    return make_markdown_table(files)
|
|
|
def update_hr_km(files):
    """Copy uploaded files into the HR source folder and index them.

    Args:
        files: Uploaded file objects; each must expose its path as ``.name``.

    Returns:
        The refreshed file listing from ``get_hr_files``.
    """
    uploaded_paths = [item.name for item in files]
    target_dir = root_file_path + hr_source_path
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    for src in uploaded_paths:
        shutil.copy(src, target_dir)
    # NOTE(review): indexing runs once after all copies; the original
    # indentation was lost, confirm it was not meant to run per file.
    add_files_to_collection(hr_source_path, hr_collection_name)

    stored_names = [Path(src).name for src in uploaded_paths]
    save_sqlite(sqlite_key, stored_names)
    return get_hr_files()
|
|
|
def clear_all_collection(collection_name):
    """Placeholder: currently does nothing and returns ``None``."""
    return None
|
|
|
def all_files_under_diretory(path):
    """Delete every regular file directly inside ``path``.

    Sub-directories are left untouched, and names starting with a dot are not
    matched by the ``*`` pattern.

    Args:
        path: Directory whose files are removed.
    """
    # os.path.join builds a separator-correct pattern.  The former
    # ``path + '\*'`` only worked with Windows separators and relied on an
    # invalid escape sequence.
    for entry in glob.glob(os.path.join(path, '*')):
        # os.remove cannot delete directories; skip them instead of raising.
        if os.path.isfile(entry) or os.path.islink(entry):
            os.remove(entry)
|
|
|
def clear_hr_datas():
    """Drop the HR collection, its source files and the cached file list.

    Returns:
        The file listing from ``get_hr_files`` after the reset.
    """
    chroma = get_chroma_client(hr_collection_name)
    chroma.delete_collection(name=hr_collection_name)
    print("Collection removed completely!")

    hr_folder = root_file_path + hr_source_path
    all_files_under_diretory(hr_folder)
    delete_sql(sqlite_key)
    return get_hr_files()
|
|
|
def num_of_collection(collection_name):
    """Return a Chinese status line with the item count of a collection."""
    collection = get_chroma_client(collection_name).get_collection(collection_name)
    number = collection.count()
    return f"目前知識卷裡有{number}卷項目"
|
|
|
def clear_tmp_collection():
    """Delete the temporary collection and its source files.

    Returns:
        The status line from ``num_of_collection`` for the temporary
        collection after the reset.
    """
    chroma = get_chroma_client(tmp_collection)
    chroma.delete_collection(name=tmp_collection)
    all_files_under_diretory(root_file_path + ks_source_path)
    return num_of_collection(tmp_collection)
|
|
|
def content_summary(split_documents):
    """Summarise ``split_documents`` with the currently selected chat model.

    All documents are stuffed into the ``{text}`` variable of the summary
    prompt.  Returns whatever ``StuffDocumentsChain.invoke`` returns.
    """
    chat_llm = current_chatllm
    summary_template = get_prompt_summary_string()
    print("prompt_string: "+summary_template)

    inner_chain = LLMChain(
        llm=chat_llm,
        prompt=PromptTemplate.from_template(summary_template),
    )
    summariser = StuffDocumentsChain(
        llm_chain=inner_chain,
        document_variable_name="text",
    )
    return summariser.invoke(split_documents)
|
|
|
def pdf_summary(file_name):
    """Load ``file_name`` with Unstructured, split it and summarise it.

    Returns the result of ``content_summary`` for the split documents.
    """
    print("file_name: "+file_name)
    elements = UnstructuredFileLoader(file_name, mode="elements", strategy="fast").load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
    return content_summary(splitter.split_documents(elements))
|
|
|
def youtube_summary(youtube_url):
    """Summarise the transcript of a YouTube video.

    Returns the ``"output_text"`` entry of the ``content_summary`` result.
    """
    transcript_loader = YoutubeLoader.from_youtube_url(
        youtube_url,
        add_video_info=True,
        language=['en', 'zh-TW'],
        translation='zh-TW',
    )
    transcript = transcript_loader.load()
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=10)
    chunks = splitter.split_documents(transcript)
    return content_summary(chunks)['output_text']
|
|
|
def summary_large_file(files):
    """Summarise the first uploaded file and return the summary text.

    Args:
        files: Uploaded file objects; each must expose its path as ``.name``.
    """
    uploaded_paths = [item.name for item in files]
    first_path = uploaded_paths[0]
    print(first_path)
    return pdf_summary(first_path)["output_text"]
|
|
|
def upload_large_file(files):
    """Return the first uploaded file's name without directory or extension.

    Args:
        files: Uploaded file objects; each must expose its path as ``.name``.
    """
    uploaded_paths = [item.name for item in files]
    first_path = Path(uploaded_paths[0])
    return first_path.stem
|
|
|
|
|
def get_chroma_client(collection_name):
    """Return the client object held by a ``Chroma`` store for ``collection_name``.

    The store is opened on the shared persist directory with the standard
    embeddings, and its private ``_client`` attribute is returned.
    """
    store = Chroma(
        embedding_function=get_openaiembeddings(),
        collection_name=collection_name,
        persist_directory=root_file_path + persist_db,
    )
    return store._client
|
|
|
def create_db():
    """Build the HR collection from the PDF files in the HR source folder."""
    initial_croma_db(
        persist_db,
        root_file_path + hr_source_path,
        "pdf",
        hr_collection_name,
    )
|
|
|
|
|
def create_html_video(file_name, width, temp_file_url):
    """Return the HTML for a square, muted, looping, autoplaying video.

    Args:
        file_name: Unused.
        width: Used for both the ``width`` and ``height`` attributes.
        temp_file_url: Value of the ``<source src=...>`` attribute.
    """
    opening_tag = f'<video width={width} height={width} autoplay muted loop>'
    source_tag = f'<source src={temp_file_url} type="video/mp4" poster="Masahiro.png">'
    return opening_tag + source_tag + '</video>'
|
|
|
def do_html_audio_speak(words_to_speak):
    """Synthesize ``words_to_speak`` with Amazon Polly and save it as an mp3.

    Args:
        words_to_speak: Text to synthesize.

    Returns:
        ``(html_audio, mp3_path)`` on success, where ``html_audio`` is an
        autoplaying ``<audio>`` snippet.  ``(None, None)`` is returned when
        Polly sends no audio stream or the file cannot be written.
    """
    audio_path = "./data/audios/tempfile.mp3"

    # SECURITY: credentials are committed in source.  Rotate them, provide
    # them through the environment variables below, and remove the fallbacks.
    polly_client = boto3.Session(
        aws_access_key_id=os.environ.get(
            "POLLY_AWS_ACCESS_KEY_ID", "AKIAV7Q7AAGW54RBR6FZ"),
        aws_secret_access_key=os.environ.get(
            "POLLY_AWS_SECRET_ACCESS_KEY", "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"),
        region_name='us-west-2'
    ).client('polly')

    language_code = "cmn-CN"
    engine = NEURAL_ENGINE
    voice_id = "Zhiyu"

    print("voice_id: "+voice_id+"\nlanguage_code="+language_code)
    response = polly_client.synthesize_speech(
        Text=words_to_speak,
        OutputFormat='mp3',
        VoiceId=voice_id,
        LanguageCode=language_code,
        Engine=engine
    )

    if "AudioStream" not in response:
        print("Could not stream audio")
        return None, None

    # closing() guarantees the stream is released whatever happens below.
    with closing(response["AudioStream"]) as stream:
        try:
            # A fresh checkout may not have the audio folder yet.
            os.makedirs(os.path.dirname(audio_path), exist_ok=True)
            with open(audio_path, 'wb') as f:
                f.write(stream.read())
            # The file is closed before Gradio is asked to serve it.
            temp_aud_file = gr.File(audio_path)
            temp_aud_file_url = "/file=" + temp_aud_file.value['name']
            html_audio = f'<audio autoplay><source src={temp_aud_file_url} type="audio/mp3"></audio>'
        except IOError as error:
            print(error)
            return None, None

    return html_audio, audio_path
|
|
|
def do_html_video_speak():
    """Send the last synthesized mp3 to the lipsync API and save the video.

    Returns:
        ``(res, html_video, video_path)``: the HTTP response, an HTML snippet
        (``<pre>no video</pre>`` when no video was produced) and the path the
        video is written to.
    """
    audio_path = "./data/audios/tempfile.mp3"
    video_path = "./data/videos/tempfile.mp4"

    # SECURITY: the API token is committed in source.  Rotate it, provide it
    # through EXH_API_KEY, and remove the fallback.
    key = os.environ.get(
        "EXH_API_KEY",
        "eyJhbGciOiJIUzUxMiJ9.eyJ1c2VybmFtZSI6ImNhdHNreXR3QGdtYWlsLmNvbSJ9.OypOUZF-xv4-b8i9F4_aaMQiJpxv0mXRT5kyuJwTMXVd4awV-O-Obntp--AqGghNNowzQ9oG7zArSnQjz2vQgg",
    )
    url = "https://api.exh.ai/animations/v2/generate_lipsync_from_audio"
    payload = {
        "animation_pipeline": "high_quality",
        "idle_url": "https://ugc-idle.s3-us-west-2.amazonaws.com/5fd9ba1b1607b39a4d559300c1e35bee.mp4"
    }
    headers = {
        "accept": "application/json",
        "authorization": f"Bearer {key}"
    }

    # Keep the upload handle in a ``with`` so it is closed after the request;
    # it used to be opened inline and never closed.
    with open(audio_path, "rb") as audio_file:
        files = {"audio_file": (audio_path, audio_file, "audio/mpeg")}
        res = requests.post(url, data=payload, files=files, headers=headers)

    print("res.status_code: ", res.status_code)

    html_video = '<pre>no video</pre>'
    # res.content is always bytes, so the old isinstance() test also saved
    # error bodies as "video".  Require a successful, non-empty response.
    if res.ok and res.content:
        print("len(res.content)): ", len(res.content))

        with open(video_path, 'wb') as f:
            f.write(res.content)
        temp_file = gr.File(video_path)
        temp_file_url = "/file=" + temp_file.value['name']
        html_video = f'<video width={TALKING_HEAD_WIDTH} height={TALKING_HEAD_WIDTH} autoplay><source src={temp_file_url} type="video/mp4" poster="Masahiro.png"></video>'
    else:
        print('video url unknown')
    return res, html_video, video_path
|
|
|
def kh_update_km(files):
    """Copy uploaded files into the knowledge-source folder and index them.

    Args:
        files: Uploaded file objects; each must expose its path as ``.name``.

    Returns:
        The status line from ``num_of_collection`` for the temporary
        collection.
    """
    uploaded_paths = [item.name for item in files]
    target_dir = root_file_path + ks_source_path

    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    for src in uploaded_paths:
        shutil.copy(src, target_dir)
    # NOTE(review): indexing runs once after all copies; the original
    # indentation was lost, confirm it was not meant to run per file.
    add_files_to_collection(ks_source_path, tmp_collection)

    return num_of_collection(tmp_collection)
|
|
|
def generate_autolayout(description, template):
    """Run ``template`` (with a ``{text}`` slot) on ``description`` via Azure chat.

    Returns:
        Whatever ``LLMChain.invoke`` returns for the description; the result
        is also printed.
    """
    chat_llm = get_chat_model(model_type="azure", model_name="")
    layout_prompt = PromptTemplate(input_variables=["text"], template=template)
    layout_output = LLMChain(llm=chat_llm, prompt=layout_prompt).invoke(description)
    print(layout_output)
    return layout_output
|
|
|
class Logger:
    """Stream object that duplicates every write to stdout and a log file.

    ``terminal`` is whatever ``sys.stdout`` was at construction time and
    ``log`` is the file opened for writing (UTF-8, truncated).
    """

    def __init__(self, filename):
        self.terminal = sys.stdout
        self.log = open(filename, "w", encoding='UTF-8')

    def write(self, message):
        """Write ``message`` to the terminal first, then to the log file."""
        for sink in (self.terminal, self.log):
            sink.write(message)

    def flush(self):
        """Flush the terminal first, then the log file."""
        for sink in (self.terminal, self.log):
            sink.flush()

    def isatty(self):
        """Always report a non-interactive stream."""
        return False
|
|
|
def read_logs():
    """Return the contents of ``output.log`` with ANSI escape sequences removed.

    ``sys.stdout`` is flushed first so pending output reaches the file.
    """
    sys.stdout.flush()

    with open("output.log", "r", encoding='UTF-8') as log_file:
        raw_text = log_file.read()
    return re.sub(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])', '', raw_text)
|
|
|
def lunch_style(demo):
    """Redirect stdout to ``output.log`` and launch ``demo`` per ``sys.argv``.

    Supported invocations:
      * no arguments: launch with default host and port, no auth;
      * ``server``: launch on the built-in address with basic auth;
      * three arguments: the 2nd and 3rd are host and port, with basic auth.
    Anything else prints a usage line and does not launch.
    """
    sys.stdout = Logger("output.log")
    demo.load(read_logs, None, None, every=1)

    argc = len(sys.argv)
    common = dict(
        ssl_verify=False,
        share=True,
        allowed_paths=[root_file_path, root_file_path+hr_source_path],
    )

    if argc == 1:
        print("running server as default value")
        demo.launch(**common)
        return

    if argc == 2 and sys.argv[1] == "server":
        host, port = "10.51.50.39", 3100
    elif argc == 4:
        host, port = sys.argv[2], sys.argv[3]
    else:
        print("syntax: python <your_app>.py [server {ip_address, port}] ")
        return

    print(f"running server on http://{host}:{port}")
    # NOTE(review): the basic-auth credentials are hard-coded here.
    demo.launch(**common, auth=("Foxconn", "Foxconn123!"), server_name=host, server_port=int(port))
|
|
|
|
|
def poc_init():
    """Initialise the module-level ``current_llm`` / ``current_chatllm`` to Azure."""
    global current_chatllm, current_llm
    current_llm = get_llm_model(model_type="azure", model_name="")
    current_chatllm = get_chat_model(model_type="azure", model_name="")
|
|
|
def gradio_run(): |
|
print("User Login") |
|
poc_init() |
|
with gr.Blocks(theme='bethecloud/storj_theme') as demo: |
|
with gr.Row(): |
|
gr.Markdown("# HH Azure Openai Demo") |
|
|
|
with gr.Row(): |
|
gr.Markdown(""" |
|
------ |
|
## Playground |
|
請切換下方Tab 鍵試驗各項功能 |
|
|
|
""") |
|
with gr.Tab("法務AI幫手"): |
|
legal_path = "./data/" |
|
quotation_file = "legal_quotation_prompt.txt" |
|
contract_file = "legal_contract_prompt.txt" |
|
|
|
def load_prompt_from_file(typeString): |
|
if typeString == "保密合約": |
|
_path_string = legal_path + contract_file |
|
else: |
|
_path_string = legal_path + quotation_file |
|
f = open(_path_string, 'r', encoding="utf-8") |
|
return_string= f.read() |
|
f.close() |
|
return return_string |
|
def save_func(typeString, prompt_string): |
|
if typeString == "保密合約": |
|
_path_string = legal_path + contract_file |
|
else: |
|
_path_string = legal_path + quotation_file |
|
f = open(_path_string, "w", encoding="utf-8") |
|
f.write(prompt_string) |
|
f.close() |
|
def restore_func(typeString): |
|
if typeString == "保密合約": |
|
content_string = default_legal_contract_prompt |
|
else: |
|
content_string = default_legal_quotation_prompt |
|
save_func(typeString, content_string) |
|
return content_string |
|
def change_prompt(inputString): |
|
global prompt_string |
|
prompt_string = inputString |
|
return inputString |
|
|
|
def change_model_var(input_component): |
|
global current_chatllm |
|
global current_llm |
|
print(input_component) |
|
match input_component: |
|
case "Azure-GPT4": |
|
_current_llm = get_llm_model(model_type="azure",model_name="") |
|
_current_chatllm = get_chat_model(model_type="azure", model_name="") |
|
case "通義千問1.5-72B": |
|
_current_llm = get_llm_model(model_type="ollama", model_name="qwen:72b") |
|
_current_chatllm = get_chat_model(model_type="ollama", model_name="qwen:72b") |
|
case "零一萬物-34B": |
|
_current_llm = get_llm_model(model_type="ollama", model_name="yi:34b-chat") |
|
_current_chatllm = get_chat_model(model_type="ollama", model_name="yi:34b-chat") |
|
case _: |
|
raise gr.Error("the model is not supported in your Ollama server!") |
|
current_chatllm = _current_chatllm |
|
current_llm = _current_llm |
|
|
|
gr.Markdown(""" |
|
### 面版說明: |
|
操作介面全部都在左側, 右側是摘要內容. |
|
### 操作步驟 |
|
1. 選擇摘要的類型: 選 `保密合約` 或 `報價單` |
|
2. 微調prompt內容: 直接點選 `prompt對話框` 修改文字內容 |
|
3. 上傳檔案: 支援PDF/doc/docx 等格式 |
|
|
|
""") |
|
gr.Markdown(""" |
|
--- |
|
""") |
|
with gr.Row(): |
|
with gr.Column(): |
|
with gr.Group(): |
|
ai_model = gr.Radio(choices=["Azure-GPT4", "通義千問1.5-72B", "零一萬物-34B"], label="1. 選擇AI模型", info="除Azure之外, 其他均是內部AI Model", type="value", value="Azure-GPT4", interactive=True) |
|
with gr.Group(): |
|
contract_type = gr.Radio(choices=["報價單","保密合約"], |
|
label="2. 請選擇摘要類型", |
|
info="選擇不一樣的摘要類型,會改變下方的prompt 內容", |
|
type="value", |
|
value="報價單", |
|
interactive=True) |
|
_firstString = load_prompt_from_file("報價單") |
|
with gr.Group(): |
|
prompt_textbox = gr.Textbox(_firstString, |
|
lines=20, |
|
max_lines=20, |
|
label="3. 設定Prompt內容", |
|
interactive=True) |
|
prompt_textbox.change(change_prompt, inputs=prompt_textbox) |
|
saveBtn = gr.Button("保存現有Prompt") |
|
restoreBtn = gr.Button("回覆預設Prompt") |
|
|
|
with gr.Group(): |
|
file_name_field = gr.Textbox(max_lines=1, label="4. 上傳檔案(可接受text, pdf, docx, csv 格式)", placeholder="目前沒有上傳保密合約或報價單") |
|
upload_button = gr.UploadButton("上傳", |
|
file_types=["text", ".pdf", ".csv", ".docx", ".doc"], file_count="multiple") |
|
|
|
with gr.Column(): |
|
summary_text = gr.Textbox() |
|
summary_text.label = "AI 摘要:" |
|
summary_text.change = False |
|
summary_text.lines = 38 |
|
summary_text.max_lines = 38 |
|
|
|
contract_type.change(fn=load_prompt_from_file, inputs=contract_type, outputs=prompt_textbox) |
|
ai_model.input(fn=change_model_var, inputs=ai_model) |
|
|
|
saveBtn.click(save_func, inputs=[contract_type, prompt_textbox],) |
|
restoreBtn.click(restore_func, inputs=contract_type, outputs=prompt_textbox) |
|
upload_button.upload(upload_large_file, upload_button, file_name_field).\ |
|
then(change_prompt,inputs=prompt_textbox).\ |
|
then(summary_large_file, upload_button, summary_text) |
|
with gr.Tab("設計系統"): |
|
with gr.Row(): |
|
def change_prompt(inputString): |
|
template_string = inputString |
|
return template_string |
|
def auto_layout(description, template): |
|
html_string = generate_autolayout(description, template) |
|
return html_string |
|
with gr.Column(scale=1): |
|
file_list = gr.Textbox(get_hr_files, label="1. 已存在知識庫的檔案", |
|
placeholder="沒有任何檔案存在", max_lines=5, lines=5) |
|
upload_button = gr.UploadButton("上傳UX/UI 知識庫檔案(text,pdf,docx,csv)",file_types=["text", ".pdf", ".docx", ".csv"],file_count="multiple") |
|
upload_button.upload(update_hr_km, inputs=upload_button, outputs=file_list) |
|
cleanDataBtn = gr.Button(value="刪除所有知識以及檔案") |
|
cleanDataBtn.click(clear_hr_datas, outputs=file_list) |
|
with gr.Column(scale=4): |
|
autolayout_template = """你是一個資深的UX designer, 請依下方的需求, |
|
給我一個web UI 的設計畫面, 使用html 語法; 除html 語法之外, 你不要說其他的話 |
|
|
|
需求: |
|
''' |
|
{text} |
|
''' |
|
""" |
|
prompt_textbox = gr.Textbox(autolayout_template, lines=8, max_lines=8, label="2. Prompt") |
|
update_btn = gr.Button("更新Prompt") |
|
with gr.Row(): |
|
msg = gr.Textbox(placeholder="輸入說明讓AI 明白你想要的畫面", label="3. 描述") |
|
with gr.Row(): |
|
content = gr.HTML("") |
|
update_btn.click(change_prompt, prompt_textbox, prompt_textbox) |
|
msg.submit(auto_layout, [msg,prompt_textbox], content, queue=True) |
|
|
|
with gr.Tab("Talk to Data"): |
|
def messageHandle(message, history): |
|
return message |
|
def upload_file(files): |
|
file_paths = [file.name for file in files] |
|
print(file_paths) |
|
return file_paths |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
file_output = gr.File() |
|
upload_button = gr.UploadButton("Click to Upload a File", file_types=["pdf"], file_count="multiple") |
|
upload_button.upload(upload_file, upload_button, file_output) |
|
with gr.Column(scale=2): |
|
gr.ChatInterface(fn=messageHandle, submit_btn=None, examples=["請摘要這份文件的內容", "請列出這份會議紀錄的逐字稿"]) |
|
pass |
|
demo.queue(concurrency_count=10) |
|
lunch_style(demo) |
|
|
|
def test():
    """Manual smoke test: ask the ``llama2`` Ollama model to summarise a passage."""
    from langchain_community.llms import Ollama

    smoke_llm = Ollama(model="llama2", base_url="http://10.51.50.39:3000")
    article_prompt = "請用繁體中文摘要下列文章: 依據聯合國 2014 年資料顯示,2050年都市人口將超過全球人口的70%,且居住人口超過千萬人的巨型城市(Megacity)將達到29個,以及幾個大城市組成的集合城市(Conurbation),如德國魯爾區(Ruhr)、荷蘭蘭斯台德(Randstad)、美國紐約與紐澤西等。因為都市人口密度的持續增加,將會帶來交通、安全、汙染及醫療等城市居住與治理的新挑戰。換句話說,都市化導致城市擠得水洩不通,隨之衍生的問題,要靠科技來超前部署。「智慧城市」最早源自IBM的智慧地球(Smart Planet)此一概念,意即以智慧運算系統來改善現在以及未來的生活,其後也衍生出多種不同的概念性名詞,如:資訊城市(Information City)、數位城市(Digital City)及無所不在的城市(Ubiquitous City)等,以各種城市智慧基礎建設投資、進行資通訊技術發展與應用,以達到城市的永續發展、改善人民生活品質與提升城市競爭力。"
    answer = smoke_llm.invoke(article_prompt)
    print(answer)
|
|
|
# Script entry point.  NOTE: this runs unconditionally, so importing the
# module also builds the UI and starts the server.
gradio_run()
|
|
|
|
|
|