File size: 1,892 Bytes
cd607b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# + tags=["hide_inp"]

# Markdown description of this demo; it is passed to `prompt.to_gradio(...)`
# as `description=desc` further down in this file.
desc = """
# NER

Notebook implementation of named entity recognition.
Adapted from [promptify](https://github.com/promptslab/Promptify/blob/main/promptify/prompts/nlp/templates/ner.jinja).
"""
# -

import json

import minichain

# Prompt to extract NER tags as JSON

class NERPrompt(minichain.TemplatePrompt):
    """Ask the model to tag named entities and return them as JSON.

    The prompt text comes from the template file named by ``template_file``.
    The model's reply is expected to be a JSON document, for example
    ``[{"T": "City", "E": "New York"}]``.
    """

    template_file = "ner.pmpt.tpl"

    def parse(self, response, inp):
        """Decode the model's JSON reply into Python objects.

        ``inp`` is not used. An invalid JSON reply makes ``json.loads``
        raise ``json.JSONDecodeError``.
        """
        entities = json.loads(response)
        return entities

# Use the extracted NER tags to ask a simple question.

class TeamPrompt(minichain.Prompt):
    """Ask the model to describe the basketball teams found by NER.

    The input is the parsed NER output: a list of entity dicts shaped like
    ``{"T": <label>, "E": <entity text>}``.
    """

    def prompt(self, inp):
        """Build a question listing every entity labelled ``"Team"``.

        Names are joined with ", " so multi-word names such as "Miami Heat"
        stay distinguishable from their neighbours. The entity list comes
        from parsed model output and may be malformed. Entries that are not
        dicts, or that lack the ``"T"`` or ``"E"`` key, are skipped instead
        of raising ``KeyError``/``TypeError``.
        """
        teams = [
            str(entity["E"])
            for entity in inp
            if isinstance(entity, dict)
            and entity.get("T") == "Team"
            and "E" in entity
        ]
        return "Can you describe these basketball teams? " + ", ".join(teams)

    def parse(self, response, inp):
        """Return the model's text reply unchanged; ``inp`` is not used."""
        return response

# Run the system.

with minichain.start_chain("ner") as backend:
    # Give each prompt its own OpenAI backend. The tuple is evaluated left
    # to right, so the NER prompt's backend is created first.
    ner_prompt, team_prompt = (
        NERPrompt(backend.OpenAI()),
        TeamPrompt(backend.OpenAI()),
    )
    # Chain the prompts: the entity list parsed by NERPrompt is presumably
    # what TeamPrompt.prompt receives as `inp`.
    prompt = ner_prompt.chain(team_prompt)
    # Example of calling the chain directly, without the Gradio UI:
    # results = prompt(
    #     {
    #         "text_input": "An NBA playoff pairing a year ago, the 76ers (39-20) meet the Miami Heat (32-29) for the first time this season on Monday night at home.",
    #         "labels": ["Team", "Date"],
    #         "domain": "Sports",
    #     }
    # )
    # print(results)

# Wrap the chain in a Gradio app. Each example row supplies values for
# `fields` in order: input text, comma-separated labels, domain.
gradio = prompt.to_gradio(
    fields=["text_input", "labels", "domain"],
    examples=[
        [
            "An NBA playoff pairing a year ago, the 76ers (39-20) meet the Miami Heat (32-29) for the first time this season on Monday night at home.",
            "Team, Date",
            "Sports",
        ]
    ],
    description=desc,
)


# Only start the web server when run as a script, not on import.
if __name__ == "__main__":
    gradio.launch()

    
# View prompt examples.

# + tags=["hide_inp"]
# NERPrompt().show(
#     {
#         "input": "I went to New York",
#         "domain": "Travel",
#         "labels": ["City"]
#     },
#     '[{"T": "City", "E": "New York"}]',
# )
# # -

# # View log.

# minichain.show_log("ner.log")