File size: 4,773 Bytes
8de89ad
 
 
 
 
908f4f9
8de89ad
 
d9c46e5
8de89ad
 
 
60632fd
4ca4fe8
a21f45a
dc8c26d
65124fb
dc8c26d
 
a21f45a
 
 
fe930dc
8de89ad
 
8b3e201
8de89ad
 
 
 
 
 
999ca47
8b3e201
5512dad
 
808f2e9
8de89ad
b386fa8
 
 
808f2e9
b386fa8
 
 
808f2e9
8de89ad
 
 
 
 
273f67e
8de89ad
 
 
 
7f7fd94
8de89ad
273f67e
8de89ad
 
 
 
 
10d12af
d09c314
8b2c11c
8de89ad
 
908f4f9
8b3e201
8de89ad
a8a4a02
 
5512dad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8de89ad
 
 
 
5512dad
 
8de89ad
 
 
 
 
 
b386fa8
8de89ad
 
 
 
 
5512dad
8de89ad
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import json
import logging
import os
import re
import string
import html

import gradio as gr
import nh3
from elasticsearch import Elasticsearch
from elasticsearch_dsl import Search, Q

# Shared Elasticsearch client used by the query and search helpers in this module.
# The cluster address is read from the `host` environment variable; os.environ.get
# returns None when it is unset (NOTE(review): confirm how the installed
# elasticsearch-py version treats a None host). HTTP compression is enabled, and
# `timeout=200` is presumably in seconds — confirm against the client version.
es = Elasticsearch(os.environ.get("host"), http_compress=True, timeout=200)
def mark_tokens_bold(text, query):
    """Highlight every occurrence of the query's tokens in *text*.

    A query wrapped in double quotes is treated as one literal phrase (quotes
    stripped); otherwise it is split on single spaces into tokens. Matching is
    literal and case-sensitive. Each match is wrapped in a yellow bold span.

    Args:
        text: the string to decorate.
        query: the raw user query.

    Returns:
        *text* with every match wrapped in highlight markup, or *text* unchanged
        when the query yields no non-empty tokens.
    """
    if query.startswith('"') and query.endswith('"'):
        # Keep the phrase as ONE token; iterating the bare string would yield characters.
        tokens = [query[1:-1]]
    else:
        tokens = query.split(" ")
    # Empty tokens (from repeated spaces or '""') would match between every character.
    tokens = {token for token in tokens if token}
    if not tokens:
        return text
    # A single pass over the text means markup inserted for one token is never
    # re-matched by another token; longest alternatives first so a token that is a
    # prefix of another does not shadow the longer one.
    pattern = re.compile(
        "|".join(re.escape(token) for token in sorted(tokens, key=len, reverse=True))
    )
    # A function replacement inserts the matched text verbatim, avoiding re.sub's
    # backslash-escape processing of replacement strings.
    return pattern.sub(
        lambda match: "<span style='color: #e6b800;'><b>" + match.group(0) + "</b></span>",
        text,
    )


def process_results(results, query):
    """Render search hits as an HTML fragment for the results panel.

    Args:
        results: list of dicts with keys ``"text"`` (the code snippet, expected to
            already be HTML-safe), ``"repository"``, ``"license"`` and ``"language"``.
        query: the user query; currently unused here because token highlighting
            is disabled, kept for interface compatibility.

    Returns:
        An HTML string with one metadata header plus a scrollable ``<pre><code>``
        box per result, or a "No results retrieved." message when *results* is empty.
    """
    if not results:
        return """<br><p>No results retrieved.</p><br><hr>"""

    blocks = []
    for result in results:
        text_html = result["text"]
        # Token highlighting via mark_tokens_bold is currently disabled.
        # Metadata comes straight from the index, so escape it before embedding in HTML.
        repository = html.escape(str(result["repository"]))
        license_name = html.escape(str(result["license"]))
        language = html.escape(str(result["language"]))
        # ~20px per line, capped at 600px so long snippets scroll instead of growing.
        code_height = min(600, (text_html.count("\n") + 1) * 20)
        blocks.append("""\
        <p style='font-size:16px; text-align: left;'><b>Source: </b><span style='color: #00134d;'>{}</span>&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;<b>Language:</b> \
        <span style='color: #00134d;'>{}</span>&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;<b>License: </b><span style='color: #00134d;'>{}</span></p>
        <br>
        <pre style='height: {}px; overflow-y: scroll; overflow-x: hidden; color: #d9d9d9;border: 1px solid #e6b800; padding: 10px'><code>{}</code></pre>
        <br>
        <hr>
        <br>
        """.format(repository, language, license_name, code_height, text_html))
    return "".join(blocks)


def _run_content_query(query_type, query, num_results):
    """Execute one *query_type* query on the ``content`` field of the configured index.

    The index name is read from the ``index`` environment variable on every call.
    Returns the executed elasticsearch_dsl response, which iterates over hits.
    """
    s = Search(using=es, index=os.environ.get("index")).query(query_type, content=query)
    # Normalise to int in case the UI slider delivers a float (e.g. 10.0), so the
    # request's result size is integral.
    return s[: int(num_results)].execute()


def match_query(query, num_results=10):
    """Run a full-text ``match`` query on ``content``; return up to *num_results* hits."""
    return _run_content_query("match", query, num_results)


def phrase_query(query, num_results=10):
    """Run an exact ``match_phrase`` query on ``content``; return up to *num_results* hits."""
    return _run_content_query("match_phrase", query, num_results)

def _hit_to_result(hit):
    """Convert one Elasticsearch hit into the dict shape process_results consumes."""
    licenses = hit.license
    return {
        # HTML-escape the snippet so it renders as literal code inside <code>,
        # then pass it through nh3.clean as an extra sanitisation step.
        "text": nh3.clean(html.escape(hit.content)),
        "repository": f"{hit.repository}/{hit.path}",
        # Guard against hits with an empty license list instead of raising IndexError.
        "license": licenses[0] if licenses else "unknown",
        "language": hit.language,
    }


def search(query, num_results=10):
    """Search the code index for *query* and return the rendered HTML results.

    A query wrapped in double quotes is sent as an exact phrase query (quotes
    stripped); anything else is sent as a regular match query.

    Args:
        query: the (already stripped) user query.
        num_results: maximum number of hits to retrieve.

    Returns:
        HTML produced by process_results.
    """
    logger = logging.getLogger(__name__)
    # Only pay for the extra round-trip to the cluster when debug logging is enabled.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("Elasticsearch ping: %s", es.ping())
    if query.startswith('"') and query.endswith('"'):
        response = phrase_query(query[1:-1], num_results=num_results)
    else:
        response = match_query(query, num_results=num_results)
    results = [_hit_to_result(hit) for hit in response]
    return process_results(results, query)

# Markdown/HTML header shown at the top of the UI. The heading ends with a
# magnifying-glass emoji (previously mis-encoded as mojibake).
description = """# <p style="text-align: center;"><span style='color: #e6b800;'>StarCoder:</span> Dataset Search 🔍 </p>
<span>When using <a href="https://huggingface.co/bigcode/large-model" style="color: #e6b800;">StarCoder</a> to generate code, it might produce exact copies of code in the pretraining dataset. \
In that case, the code license might have requirements to comply with. With this search tool, our aim is to help in identifying if the code belongs to an existing repository. For exact matches, enclose your query in double quotes.</span>"""

# Monochrome base theme with indigo/blue/slate hues, small corner radii, and
# Open Sans as the primary font with generic sans-serif fallbacks.
_theme_fonts = [gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"]
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=_theme_fonts,
)
# Custom CSS for the Blocks app. The query textbox (elem_id "q-input") uses a
# monospace font so pasted code keeps its alignment.
# NOTE(review): the leading generic `monospace` always resolves, so the
# 'Consolas'/Courier entries are never reached — confirm the intended order.
monospace_css = """
#q-input textarea {
    font-family: monospace, 'Consolas', Courier, monospace;
}
"""

css = monospace_css + ".gradio-container {color: black}"


if __name__ == "__main__":
    demo = gr.Blocks(
        theme=theme,
        css=css,
    )

    with demo:
        with gr.Row():
            gr.Markdown(value=description)
        with gr.Row():
            query = gr.Textbox(lines=5, placeholder="Type your query here...", label="Query", elem_id="q-input")
        with gr.Row():
            k = gr.Slider(1, 100, value=10, step=1, label="Max Results")
        with gr.Row():
            submit_btn = gr.Button("Submit")
        with gr.Row():
            results = gr.HTML(label="Results", value="")

        def submit(query_text, max_results):
            """Search for the textbox contents and return HTML for the results panel.

            Blank (or missing) queries clear the panel without querying Elasticsearch.
            Exactly one value is returned because one output component is wired up.
            """
            # Check for None before calling .strip() on it.
            if query_text is None:
                return ""
            query_text = query_text.strip()
            if not query_text:
                return ""
            return search(query_text, max_results)

        query.submit(fn=submit, inputs=[query, k], outputs=[results])
        submit_btn.click(submit, inputs=[query, k], outputs=[results])

    demo.launch(enable_queue=True, debug=True)