mrchuy committed on
Commit
2720910
β€’
1 Parent(s): 8624101

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -113
app.py CHANGED
@@ -1,122 +1,26 @@
1
- import io
2
- import random
3
- from typing import List, Tuple
4
-
5
- import aiohttp
6
- import panel as pn
7
- from PIL import Image
8
- # from transformers import CLIPModel, CLIPProcessor
9
  from transformers import AutoTokenizer, AutoModelForCausalLM
10
 
11
- pn.extension(design="bootstrap", sizing_mode="stretch_width")
12
-
13
- ICON_URLS = {
14
- "brand-github": "https://github.com/holoviz/panel",
15
- "brand-twitter": "https://twitter.com/Panel_Org",
16
- "brand-linkedin": "https://www.linkedin.com/company/panel-org",
17
- "message-circle": "https://discourse.holoviz.org/",
18
- "brand-discord": "https://discord.gg/AXRHnJU6sP",
19
- }
20
-
21
-
22
 
23
- def load_tokenizer_model():
24
- tokenizer = AutoTokenizer.from_pretrained("Salesforce/xgen-7b-8k-base", trust_remote_code=True)
25
- model = AutoModelForCausalLM.from_pretrained("Salesforce/xgen-7b-8k-base", torch_dtype=torch.bfloat16)
26
- return tokenizer,model
27
 
28
 
 
29
 
30
- async def process_inputs(class_names: List[str], user_text: str):
31
- """
32
- High level function that takes in the user inputs and returns the
33
- classification results as panel objects.
34
- """
35
- try:
36
- main.disabled = True
37
- if not user_text:
38
- yield "##### ⚠️ Provide some user text URL"
39
- return
40
-
41
- yield "##### βš™ Fetching and running model..."
42
- try:
43
- inputs = tokenizer("The world is", return_tensors="pt")
44
- sample = model.generate(**inputs, max_length=128)
45
-
46
- # pil_img = await open_image_url(image_url)
47
- # img = pn.pane.Image(pil_img, height=400, align="center")
48
- except Exception as e:
49
- yield f"##### πŸ˜” Something went wrong, please try a different URL!"
50
- return
51
-
52
- # class_items = class_names.split(",")
53
- # class_likelihoods = get_similarity_scores(class_items, pil_img)
54
 
55
- # build the results column
56
-
57
- results = pn.Column("##### πŸŽ‰ Here are the results!", tokenizer.decode(sample[0])))
58
-
59
- # for class_item, class_likelihood in zip(class_items, class_likelihoods):
60
- # row_label = pn.widgets.StaticText(
61
- # name=class_item.strip(), value=f"{class_likelihood:.2%}", align="center"
62
- # )
63
- # row_bar = pn.indicators.Progress(
64
- # value=int(class_likelihood * 100),
65
- # sizing_mode="stretch_width",
66
- # bar_color="secondary",
67
- # margin=(0, 10),
68
- # design=pn.theme.Material,
69
- # )
70
- # results.append(pn.Column(row_label, row_bar))
71
- yield results
72
- finally:
73
- main.disabled = False
74
-
75
-
76
- # create widgets
77
- randomize_url = pn.widgets.Button(name="Randomize URL", align="end")
78
-
79
- image_url = pn.widgets.TextInput(
80
- name="Image URL to classify",
81
- value=pn.bind(random_url, randomize_url),
82
- )
83
- class_names = pn.widgets.TextInput(
84
- name="Comma separated class names",
85
- placeholder="Enter possible class names, e.g. cat, dog",
86
- value="cat, dog, parrot",
87
- )
88
-
89
- input_widgets = pn.Column(
90
- "##### Add some text and do something",
91
- pn.Row(image_url, randomize_url),
92
- class_names,
93
- )
94
-
95
- # add interactivity
96
- interactive_result = pn.panel(
97
- pn.bind(process_inputs, image_url=image_url, class_names=class_names),
98
- height=600,
99
- )
100
-
101
- # add footer
102
- footer_row = pn.Row(pn.Spacer(), align="center")
103
- for icon, url in ICON_URLS.items():
104
- href_button = pn.widgets.Button(icon=icon, width=35, height=35)
105
- href_button.js_on_click(code=f"window.open('{url}')")
106
- footer_row.append(href_button)
107
- footer_row.append(pn.Spacer())
108
 
109
- # create dashboard
110
- main = pn.WidgetBox(
111
- input_widgets,
112
- interactive_result,
113
- footer_row,
114
- )
115
 
116
- title = "Xgen input panel"
117
- pn.template.BootstrapTemplate(
118
- title=title,
119
- main=main,
120
- main_max_width="min(50%, 698px)",
121
- header_background="#F08080",
122
- ).servable(title=title)
 
1
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Single source of truth for the checkpoint name (was previously repeated
# inline in both from_pretrained calls).
MODEL_NAME = "Salesforce/xgen-7b-8k-base"

# Load the tokenizer and model once at import time so every request served
# by the Gradio app reuses the same instances.  trust_remote_code is
# required because XGen ships a custom tokenizer implementation; bfloat16
# halves the memory footprint relative to float32.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
13
 
14
def gentext(user_input="The world is", max_length=128):
    """Generate a text continuation of ``user_input`` with the XGen model.

    Parameters
    ----------
    user_input : str, optional
        Prompt text to continue (default ``"The world is"``).
    max_length : int, optional
        Maximum total token length (prompt + generation) passed to
        ``model.generate`` (default 128, matching the previous hard-coded
        value, so existing callers are unaffected).

    Returns
    -------
    dict
        ``{"output": <decoded text>}`` — the decoded sequence includes the
        prompt itself, since the full ``sample[0]`` is decoded.
    """
    inputs = tokenizer(user_input, return_tensors="pt")
    # Inference only: no_grad avoids building an autograd graph, saving
    # memory on a 7B-parameter model.
    with torch.no_grad():
        sample = model.generate(**inputs, max_length=max_length)
    return {"output": tokenizer.decode(sample[0])}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
 
 
 
 
 
 
20
 
21
# Wire the generator into a simple web UI.  Uses the current component API
# (gr.Textbox / gr.JSON): the legacy ``gr.inputs`` / ``gr.outputs``
# namespaces were removed in Gradio 3.x, and ``type="input"`` was never a
# valid Textbox type.  ``gr.outputs.Label(num_top_classes=2)`` expects
# class→confidence pairs and cannot render gentext's ``{"output": text}``
# return value, so a JSON component is used instead.
gr.Interface(
    fn=gentext,
    inputs=gr.Textbox(label="Some prompt", value="The world is"),
    outputs=gr.JSON(label="Generated text"),
    title="XGen text generation",
).launch()