File size: 2,243 Bytes
6b608bb
 
8ea22c2
6b608bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64

# Demo of a secondary feature of Guillaume-Tell: deciding whether a question should trigger the source-retrieval pipeline.
# The function should return a structured answer in JSON with two components:
## A short analysis with its reasoning.
## A boolean answer in French ("oui" or "non").


import sys, os
from pprint import pprint
from jinja2 import Environment, FileSystemLoader, meta
import yaml

import pandas as pd
from vllm import LLM, SamplingParams


# Add the current directory to the import path, then switch the working
# directory to the folder containing this script so that the relative paths
# used below (config file, templates, model directory) resolve from there.
sys.path.append(os.curdir)
_script_path = os.path.abspath(__file__)
os.chdir(os.path.dirname(_script_path))

#Specific function to deal with json format.
def get_llm_response(prompt_template, sampling_params):
    """Generate a completion for one prompt and make sure it ends with "}".

    The generation is expected to be a JSON object. The script in this file
    stops generation on "}", so the closing brace may be missing from the
    returned text; it is appended whenever the text does not already end
    with one.

    Args:
        prompt_template: The fully rendered prompt string sent to the model.
        sampling_params: Sampling parameters forwarded to ``llm.generate``.

    Returns:
        A ``(prompt, generated_text)`` tuple: the input prompt concatenated
        with the completed generation, and the completed generation alone.

    Note:
        Relies on a module-level ``llm`` object that must be created before
        this function is called.
    """
    outputs = llm.generate([prompt_template], sampling_params, use_tqdm=False)
    generated_text = outputs[0].outputs[0].text
    # endswith() also handles an empty generation, where indexing with [-1]
    # would raise IndexError.
    if not generated_text.endswith("}"):
        generated_text += "}"
    return prompt_template + generated_text, generated_text

if __name__ == "__main__":
    # Demo driver: for every prompt of mode "analysis" declared in
    # prompt_config.yaml, render its Jinja template with a sample query and
    # print the model's answer.

    with open('prompt_config.yaml', encoding="utf-8") as f:
        config = yaml.safe_load(f)

    print("prompt format:", config.get("prompt_format"))
    print(config)
    print()

    # Loop-invariant objects are built once instead of once per prompt.
    env = Environment(loader=FileSystemLoader("."))
    sampling_params = SamplingParams(temperature=0.2, top_p=0.95, max_tokens=300, stop="}")

    # Kept as a module-level name because get_llm_response() reads the global
    # ``llm``. Loaded lazily so the model is only instantiated when at least
    # one "analysis" prompt exists, and never more than once.
    llm = None

    for prompt_config in config["prompts"]:
        if prompt_config["mode"] != "analysis":
            continue

        print(f'--- prompt mode: {prompt_config["mode"]} ---')
        template = env.get_template(prompt_config["template"])

        # Parse the raw template source to discover which variables it expects.
        source = env.loader.get_source(env, template.name)
        variables = meta.find_undeclared_variables(env.parse(source[0]))

        print("variables:", variables)
        print("---")

        data = {"query": "Comment obtenir le formulaire A36 ?"}
        # Only pass a system prompt to templates that actually reference one.
        if "system_prompt" in variables:
            data["system_prompt"] = prompt_config["system_prompt"]

        rendered_template = template.render(**data)
        print(rendered_template)
        print("---")

        if llm is None:
            llm = LLM("mistral-mfs-reference-2/mistral-mfs-reference-2")

        # The first element (prompt + generation) is not needed here.
        _, generated_text = get_llm_response(rendered_template, sampling_params)
        print("Albert : ", generated_text)