import json

import gradio as gr
import torch
from transformers import pipeline

# Run on GPU when available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Cross-encoder text-classification pipeline used as a zero-shot classifier.
# Truncate from the left so the instruction appended to long inputs survives.
model = pipeline("text-classification", "iknow-lab/azou", device=device)
model.tokenizer.truncation_side = "left"

def inference(content, instruction, labels):
    # Pair "content [SEP] instruction" with every candidate label and let the
    # cross-encoder score each pair.
    query = f"{content} [SEP] {instruction}"
    inputs = model.tokenizer(
        [query] * len(labels), labels,
        truncation=True, padding=True, return_tensors="pt",
    ).to(device)

    # Normalize the per-label logits into a distribution over the labels.
    with torch.no_grad():
        scores = model.model(**inputs).logits.squeeze(1).softmax(-1).tolist()
    output = dict(zip(labels, scores))

    return output, json.dumps(output, ensure_ascii=False)


def greet(content, instruction, labels):
    # The label box takes a comma-separated list; strip stray whitespace.
    labels = [label.strip() for label in labels.split(",") if label.strip()]
    return inference(content, instruction, labels)

content = gr.TextArea(label="μž…λ ₯ λ‚΄μš©")        # input text
instruction = gr.Textbox(label="μ§€μ‹œλ¬Έ")        # instruction
labels = gr.Textbox(label="라벨(μ‰Όν‘œλ‘œ ꡬ뢄)")  # candidate labels (comma-separated)

# Example rows: (content, instruction, comma-separated candidate labels).
examples = [
    ["μ˜ˆμ „μ—λŠ” μ£Όλ§λ§ˆλ‹€ κ·Ήμž₯에 λ†€λŸ¬κ°”λŠ”λ° μš”μƒˆλŠ” μ’€ μ•ˆκ°€λŠ” νŽΈμ΄μ—μš”", "λŒ“κΈ€ 주제λ₯Ό λΆ„λ₯˜ν•˜μ„Έμš”", "μ˜ν™”,λ“œλΌλ§ˆ,κ²Œμž„,μ†Œμ„€"],
    ["인천발 KTX와 κ΄€λ ¨ν•œ μ†‘λ„μ—­ λ³΅ν•©ν™˜μŠΉμ„Όν„°κ°€ 사싀상 λ¬΄μ‚°, λ‹¨μˆœ μ² λ„Β·λ²„μŠ€ μœ„μ£Ό ν™˜μŠΉμ‹œμ„€λ‘œ λ§Œλ“€μ–΄μ§„λ‹€. 이 λ•Œλ¬Έμ— μΈμ²œμ‹œμ˜ 인천발 KTX 기점에 μ•΅μ»€μ‹œμ„€μΈ λ³΅ν•©ν™˜μŠΉμ„Όν„°λ₯Ό ν†΅ν•œ 인근 지역 경제 ν™œμ„±ν™”λ₯Ό μ΄λ€„λ‚Έλ‹€λŠ” κ³„νšμ˜ 차질이 λΆˆκ°€ν”Όν•˜λ‹€.", "κ²½μ œμ— 긍정적인 λ‰΄μŠ€μΈκ°€μš”?", "예,μ•„λ‹ˆμš”"],
    ["λ§ˆμ§€λ§‰μ—λŠ” k팝 곡연보고 쒋은 μΆ”μ–΅ λ‚¨μ•˜μœΌλ©΄ μ’‹κ² λ„€μš”", "μš•μ„€μ΄ ν¬ν•¨λ˜μ–΄μžˆλ‚˜μš”?", "μš•μ„€μ΄ μžˆμŠ΅λ‹ˆλ‹€,μš•μ„€μ΄ μ—†μŠ΅λ‹ˆλ‹€"],
]
gr.Interface(
    fn=greet,
    inputs=[content, instruction, labels],
    outputs=[gr.Label(), gr.Text(label="json")],
    examples=examples,
).launch(server_name="0.0.0.0", server_port=7860)
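
# A minimal sketch (not part of the original app) of querying the running
# interface programmatically; assumes gradio_client is installed and that
# gr.Interface exposes its default /predict endpoint on the port above:
#
#   from gradio_client import Client
#
#   client = Client("http://localhost:7860")
#   result = client.predict(
#       "μ˜ν™”κ°€ 정말 μž¬λ°Œμ—ˆμ–΄μš”",   # content
#       "λŒ“κΈ€ 주제λ₯Ό λΆ„λ₯˜ν•˜μ„Έμš”",   # instruction
#       "μ˜ν™”,κ²Œμž„",                # comma-separated labels
#       api_name="/predict",
#   )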