merve's picture
merve HF staff
Create app.py
dc4cc7d
raw
history blame
1.89 kB
import os
import gradio as gr
from gradio import FlaggingCallback
from gradio.components import IOComponent
from datasets import load_dataset
from typing import List, Optional, Any
import argilla as rg
import os
def load_data():
    """Return the first example of the Turkish instructions training split.

    The ``merve/turkish_instructions`` dataset is opened with
    ``streaming=True`` and only the first item of the resulting iterator is
    consumed.
    """
    stream = load_dataset(
        "merve/turkish_instructions",
        split="train",
        streaming=True,
    )
    return next(iter(stream))
def create_record(sample, feedback):
    """Build an Argilla text-classification record from a flagged sample.

    Args:
        sample: Mapping with the keys ``"talimat"``, ``"giriş"`` and
            ``"çıktı"`` (instruction, input and output).
        feedback: The flag option chosen by the reviewer. ``"Doğru"`` marks
            the record as ``"Validated"``; anything else leaves it as
            ``"Default"``.

    Returns:
        The ``rg.TextClassificationRecord`` that was built.

    Raises:
        KeyError: If ``sample`` lacks one of the expected keys.
    """
    status = "Validated" if feedback == "Doğru" else "Default"
    fields = {
        "talimat": sample["talimat"],
        "input": sample["giriş"],
        # The key was previously mis-encoded ("Γ§Δ±ktΔ±"), so the lookup could
        # never match the Turkish column name.
        "response": sample["çıktı"],
    }
    # TODO: the label is a placeholder; it is meant to come from the flag
    # object in Gradio.
    label = "True"
    record = rg.TextClassificationRecord(
        inputs=fields,
        annotation=label,
        status=status,
        metadata={"feedback": feedback},
    )
    print(record)
    return record
class ArgillaLogger(FlaggingCallback):
    """Gradio flagging callback that logs each flagged sample to Argilla."""

    def __init__(self, api_url, api_key, dataset_name):
        """Initialise the Argilla client and remember the target dataset.

        Args:
            api_url: URL forwarded to ``rg.init``.
            api_key: API key forwarded to ``rg.init``.
            dataset_name: Name of the dataset that ``rg.log`` writes to.
        """
        rg.init(api_url=api_url, api_key=api_key)
        self.dataset_name = dataset_name
        # Running total of flagged samples, returned by ``flag``.
        self._flag_count = 0

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """Do nothing; records go to Argilla, so no local setup is needed."""
        pass

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """Log one flagged sample to the Argilla dataset.

        Args:
            flag_data: Values of the flagged components; only the first
                entry is used and is handed to ``create_record``.
            flag_option: The flagging option chosen by the user, passed to
                ``create_record`` as the feedback.
            flag_index: Unused.
            username: Unused.

        Returns:
            The number of samples flagged through this instance so far.

        Raises:
            IndexError: If ``flag_data`` is empty.
        """
        # Only the first component value is needed. The previous unused read
        # of flag_data[1] failed whenever a single value was flagged.
        sample = flag_data[0]
        rg.log(
            name=self.dataset_name,
            records=create_record(sample, flag_option),
        )
        # getattr keeps this working for instances built without __init__.
        self._flag_count = getattr(self, "_flag_count", 0) + 1
        return self._flag_count
# Build and launch the flagging UI; flagged samples are sent to Argilla
# through ArgillaLogger.
# NOTE(review): no `fn`, `inputs` or `outputs` are passed to gr.Interface
# here, and load_data() is not called anywhere in the visible file -- confirm
# the intended wiring before deploying.
gr.Interface(
    # Title was previously mis-encoded ("DΓΌzeltme ArayΓΌzΓΌ").
    title="ALPACA Veriseti Düzeltme Arayüzü",
    description="",
    allow_flagging="manual",
    flagging_callback=ArgillaLogger(
        api_url="https://sandbox.argilla.io",
        # The API key is read from the environment rather than hard-coded.
        api_key=os.getenv("TEAM_API_KEY"),
        dataset_name="alpaca-flags",
    ),
    flagging_options=["Doğru", "Yanlış", "Belirsiz"],
).launch()