# "Who Is the Spy" (谁是卧底): a Streamlit party game in which one human player
# and several LLM-driven players each describe a secret word and then vote out
# whoever they believe holds the odd word.
#
# Streamlit re-executes this script top to bottom on every interaction, so all
# game data lives in ``st.session_state``.

import streamlit as st
from random import choices, randint, sample
import os

# Runtime dependency bootstrap: these installs must run before the imports
# below. transformers/torch are not used by the code in this file, which
# talks to the OpenAI-compatible API instead.
os.system("pip install transformers")
os.system("pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu")
os.system("pip install einops")
os.system("pip install sentencepiece")
os.system("pip install openai")

import json
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from openai import OpenAI

state = st.session_state

# Page title.
st.title("谁是卧底😎")

# Word pairs; each entry is read below via its "spy_word" and "civilian_word" keys.
# An explicit encoding keeps non-ASCII words readable regardless of the locale.
with open("word_list.json", "r", encoding="utf-8") as f:
    word_list = json.load(f)

# Avatar shown next to each speaker in the chat window.
avatar_dict = {
    "host": "🐼",
    "P1": "🚀",
    "P2": "🚄",
    "P3": "🚁",
    "P4": "🚂",
    "P5": "🚢",
    "P6": "🚤",
    "P7": "🚙",
    "P8": "🚠",
    "P9": "🚲",
    "P10": "🚜",
    "H": "🤹‍♂️",
}

# ---------------------------------------------------------------------------
# Session state defaults
# ---------------------------------------------------------------------------
# Global message log rendered in the chat window.
if "messages" not in state:
    state.messages = []

# Surviving players: dicts with "id" and "dignity" ("spy" / "civilian").
if "players" not in state:
    state.players = []

# System prompt template for AI players; it is filled with the player's word
# through str.format().
if "prompt" not in state:
    with open("prompt.txt", "r", encoding="utf-8") as f:
        state.prompt = f.read()

# OpenAI-compatible client, created once per session.
if "client" not in state:
    state.client = OpenAI(
        api_key=os.getenv("OPENAI_API_KEY"),
        base_url=os.getenv("BASE_URL"),
    )
    state.model_name = "internlm/internlm2_5-20b-chat"


def response(messages):
    """Send *messages* to the chat-completions endpoint and return the reply text.

    Args:
        messages: list of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The content of the first choice of the completion.
    """
    completion = state.client.chat.completions.create(
        model=state.model_name,
        messages=messages,
        temperature=0.5,
    )
    return completion.choices[0].message.content


def _keyword_for(player):
    """Return the secret word that matches *player*'s role."""
    if player["dignity"] == "spy":
        return state.spy_word
    return state.civilian_word


def _assign_roles(total_num, spy_num, human_is_spy):
    """Build the player list with exactly *spy_num* spies.

    The human ("H") comes first, followed by AI players "P1".."P{total_num-1}".
    If the human is a spy, only ``spy_num - 1`` AI players are made spies.
    AI spies are drawn without replacement so the configured count is met.
    """
    ai_ids = [f"P{i}" for i in range(1, total_num)]
    ai_spy_count = spy_num - 1 if human_is_spy else spy_num
    # Clamp defensively: the form's spy maximum is derived from the previous
    # run's total, so a single submit can momentarily exceed the valid range.
    ai_spy_count = max(0, min(ai_spy_count, len(ai_ids)))
    spy_ids = set(sample(ai_ids, k=ai_spy_count))

    players = [{"id": "H", "dignity": "spy" if human_is_spy else "civilian"}]
    players += [
        {"id": pid, "dignity": "spy" if pid in spy_ids else "civilian"}
        for pid in ai_ids
    ]
    return players


# ---------------------------------------------------------------------------
# Settings
# ---------------------------------------------------------------------------
# Round bookkeeping.
if "max_round" not in state:
    state.max_round = 100
if "round" not in state:
    state.round = 0

# Sidebar: choose player count, spy count and round limit.
with st.sidebar:
    st.write("游戏设置")
    with st.form(key="game_setting"):
        if "words" not in state:
            state.words = word_list[randint(0, len(word_list) - 1)]
        total_num = st.number_input("总人数", 5, 10, 5)
        spy_num = st.number_input("卧底人数", 1, total_num // 2, 1)
        max_round = st.number_input("最大回合数", 5, 10, 10)
        submitted = st.form_submit_button("保存设置")

    # On submit: store the settings and deal roles.
    if submitted:
        state.spy_word = state.words["spy_word"]
        state.civilian_word = state.words["civilian_word"]
        state.total_num = total_num
        state.spy_num = spy_num
        state.max_round = max_round

        # The human's role is a coin flip: 0 = civilian, 1 = spy.
        human_is_spy = randint(0, 1) == 1
        state.players = _assign_roles(total_num, spy_num, human_is_spy)
        state.human_word = state.spy_word if human_is_spy else state.civilian_word

    # Shown on every rerun so the human does not lose their word after the
    # first interaction.
    if "human_word" in state:
        st.write("你的关键词是{}".format(state.human_word))

# ---------------------------------------------------------------------------
# Message window
# ---------------------------------------------------------------------------
with st.container(height=300):
    for message in state.messages:
        with st.chat_message(message["id"], avatar=avatar_dict[message["id"]]):
            st.text(message["id"])
            st.write(message["message"])

# ---------------------------------------------------------------------------
# Main game loop (one script run == one UI interaction)
# ---------------------------------------------------------------------------
if state.round < state.max_round:
    if "description" not in state:
        state.description = []

    # --- Description phase: every AI player describes its word. ---
    if st.button("开始第{}轮游戏".format(state.round + 1)):
        state.messages.append({"id": "host", "message": f"第{state.round+1}轮游戏开始"})
        for player in state.players:
            # The human describes through the chat input at the bottom.
            if player["id"] == "H":
                continue
            if player["dignity"] == "spy":
                instruction = "/describe " + "请根据描述历史记录描述你的关键词,需要注意不能与已有的描述重复,下面是描述历史记录:\n"
            else:
                instruction = "/describe " + "请根据描述历史记录描述你的关键词,下面是描述历史记录:\n"
            text = response([
                {"role": "system", "content": state.prompt.format(_keyword_for(player))},
                {"role": "system", "content": instruction + "\n".join(state.description)},
            ])
            state.messages.append({"id": player["id"], "message": text})
            state.description.append(player["id"] + ":" + text)
        st.rerun()

    # --- Voting phase: AI players state their vote; the human tallies the
    # votes and picks who is eliminated. ---
    col1, col2 = st.columns([8, 2])
    with col1:
        vote_id = st.selectbox("投票对象", [a["id"] for a in state.players])
    with col2:
        if st.button("开始投票"):
            alive_ids = ",".join([a["id"] for a in state.players])
            for player in state.players:
                if player["id"] == "H":
                    continue
                text = response([
                    # Each voter reasons from its own word, not always the spy's.
                    {"role": "system", "content": state.prompt.format(_keyword_for(player))},
                    {"role": "system", "content": "/vote " + "请根据描述历史记录选择要投出的玩家,下面是描述历史记录:\n" + "\n".join(state.description) + "\n" + "当前场上存活玩家id为:\n" + alive_ids},
                ])
                state.messages.append({"id": player["id"], "message": text})
            st.rerun()

    # --- Elimination: remove the chosen player and check the win conditions.
    # vote_id is None while no players exist; ignore the click in that case.
    if st.button("投出玩家") and vote_id is not None:
        state.messages.append({"id": "host", "message": f"玩家{vote_id}被投票出局"})
        state.round += 1
        state.players = [p for p in state.players if p["id"] != vote_id]

        spies_alive = sum(1 for p in state.players if p["dignity"] == "spy")
        if spies_alive == 0:
            if state.players:
                state.messages.append({"id": "host", "message": "平民胜利!"})
        elif spies_alive >= len(state.players) // 2:
            # Spies win once they number at least half (integer division) of
            # the survivors.
            state.messages.append({"id": "host", "message": "卧底胜利!"})
        st.rerun()

    # --- Human description input, available only while the human survives. ---
    if any(p["id"] == "H" for p in state.players):
        if des := st.chat_input():
            state.messages.append({"id": "H", "message": des})
            state.description.append("H:" + des)
            st.rerun()