import html
import json
import logging
import os
import re
import string
from urllib.parse import quote

import gradio as gr
import nh3
from elasticsearch import Elasticsearch
from elasticsearch_dsl import Q, Search

logger = logging.getLogger(__name__)

# Connection settings come from the environment: `host` is the Elasticsearch
# endpoint and `index` is the index every query is run against.
es = Elasticsearch(os.environ.get("host"), http_compress=True, timeout=200)

# Queries longer than this are truncated before being sent to Elasticsearch.
# The same limit is advertised to users in `description` below.
MAX_QUERY_CHARS = 200

# A rendered code box is sized at CODE_LINE_HEIGHT_PX per line, capped at
# MAX_CODE_HEIGHT_PX (the box scrolls beyond that).
MAX_CODE_HEIGHT_PX = 600
CODE_LINE_HEIGHT_PX = 20

HIGHLIGHT_OPEN = "<b>"
HIGHLIGHT_CLOSE = "</b>"

# One HTML character reference such as `&lt;`, `&#39;` or `&#x27;`.
_HTML_ENTITY = r"&(?:#[0-9]+|#[xX][0-9a-fA-F]+|[A-Za-z][A-Za-z0-9]*);"

NO_RESULTS_HTML = """\
<br>
<p style="text-align: center; color: silver;">No results retrieved.</p>
<br>
<hr>
"""

# Every value substituted into this template must already be HTML-safe.
RESULT_TEMPLATE = """\
<p style="font-size: 16px; text-align: left;">
<b>Source:</b> <a href="{url}" target="_blank" rel="noopener noreferrer">{source}</a>
&nbsp;&nbsp;|&nbsp;&nbsp; <b>Language:</b> {language}
&nbsp;&nbsp;|&nbsp;&nbsp; <b>License:</b> {license}
</p>
<pre style="height: {height}px; overflow: auto;"><code>{code}</code></pre>
<hr>
"""


def _split_query(query, max_chars=MAX_QUERY_CHARS):
    """Split a raw user query into ``(is_phrase, text)``.

    A query wrapped in double quotes is an exact-phrase query and the quotes
    are removed; a lone ``"`` is not treated as a phrase. The remaining text
    is truncated to ``max_chars`` characters.
    """
    is_phrase = len(query) >= 2 and query.startswith('"') and query.endswith('"')
    text = query[1:-1] if is_phrase else query
    return is_phrase, text[:max_chars]


def mark_tokens_bold(text, tokens):
    """Wrap every occurrence of any of *tokens* in *text* with bold markup.

    Args:
        text: HTML-escaped text to annotate.
        tokens: strings to highlight, expressed in the same (escaped) form as
            *text*. Empty strings are ignored, since an empty pattern would
            match at every position.

    Returns:
        *text* with each match wrapped in ``HIGHLIGHT_OPEN``/``HIGHLIGHT_CLOSE``.

    All tokens are matched in a single pass, so markup inserted for one token
    is never re-scanned by another (e.g. a token ``b`` matching inside
    ``<b>``). Longer tokens are listed first so the longest token wins at a
    given position. HTML character references (``&lt;`` ...) are skipped as a
    unit so a token such as ``lt`` cannot split one apart.
    """
    unique_tokens = sorted({token for token in tokens if token}, key=len, reverse=True)
    if not unique_tokens:
        return text

    alternation = "|".join(re.escape(token) for token in unique_tokens)
    # Group 1 is a highlighted token; the second alternative consumes an entity.
    pattern = re.compile("(" + alternation + ")|" + _HTML_ENTITY)

    def _highlight(match):
        if match.group(1) is None:
            return match.group(0)  # an HTML entity: leave it intact
        # A callable replacement inserts the match verbatim; a string
        # replacement would interpret backslashes in the token as escapes.
        return HIGHLIGHT_OPEN + match.group(1) + HIGHLIGHT_CLOSE

    return pattern.sub(_highlight, text)


def process_results(results, query):
    """Render search hits as an HTML fragment for the results panel.

    Args:
        results: list of dicts with keys "text" (HTML-escaped, sanitised
            code), "repository", "path", "license" and "language".
        query: the raw user query; its words are highlighted in each snippet.

    Returns:
        An HTML string, or a "No results retrieved." notice when *results*
        is empty.
    """
    if not results:
        return NO_RESULTS_HTML

    _, query_text = _split_query(query)
    # The snippet text is HTML-escaped, so escape the query words the same
    # way or words containing <, > or & could never match. quote=False
    # because nh3.clean presumably re-serialises quotes unescaped in text
    # nodes — TODO confirm against the deployed nh3 version.
    tokens = [html.escape(word, quote=False) for word in query_text.split()]

    parts = []
    for result in results:
        text = result["text"]
        repository = str(result["repository"])
        path = str(result["path"])
        # NOTE(review): assumes `repository` is a GitHub "owner/name" slug;
        # blob/HEAD resolves to the repository's default branch.
        url = f"https://github.com/{quote(repository)}/blob/HEAD/{quote(path)}"
        code_height = min(MAX_CODE_HEIGHT_PX, (text.count("\n") + 1) * CODE_LINE_HEIGHT_PX)
        parts.append(
            RESULT_TEMPLATE.format(
                url=html.escape(url),
                source=html.escape(f"{repository}/{path}"),
                language=html.escape(str(result["language"])),
                license=html.escape(str(result["license"])),
                height=code_height,
                code=mark_tokens_bold(text, tokens),
            )
        )
    return "".join(parts)


def _run_query(query_type, query, num_results):
    """Run one Elasticsearch query of *query_type* on the `content` field."""
    search_request = Search(using=es, index=os.environ.get("index"))
    search_request = search_request.query(Q(query_type, content=query))
    return search_request[:num_results].execute()


def match_query(query, num_results=10):
    """Full-text `match` search on the `content` field; returns the response."""
    return _run_query("match", query, num_results)


def phrase_query(query, num_results=10):
    """Exact `match_phrase` search on the `content` field; returns the response."""
    return _run_query("match_phrase", query, num_results)


def _hit_to_result(hit):
    """Convert one Elasticsearch hit into the dict shape `process_results` expects."""
    licenses = getattr(hit, "license", None)
    return {
        # Escape first so the code is displayed as text, then sanitise the
        # result as a second line of defence.
        "text": nh3.clean(html.escape(hit.content)),
        "repository": hit.repository,
        "path": hit.path,
        "license": licenses[0] if licenses else "unknown",
        "language": hit.language,
    }


def search(query, num_results=10):
    """Search the index for *query* and return the rendered HTML results.

    A query wrapped in double quotes runs as an exact phrase search; anything
    else runs as a full-text match. Queries are truncated to
    ``MAX_QUERY_CHARS`` characters.
    """
    is_phrase, query_text = _split_query(query)
    run = phrase_query if is_phrase else match_query
    response = run(query_text, num_results=num_results)
    results = [_hit_to_result(hit) for hit in response]
    logger.debug("search phrase=%s chars=%d hits=%d", is_phrase, len(query_text), len(results))
    return process_results(results, query)


description = f"""# StarCoder: Dataset Search 🔍

When using StarCoder to generate code, it might produce exact copies of code in the pretraining dataset. \
In that case, the code license might have requirements to comply with. \
With this search tool, our aim is to help in identifying if the code belongs to an existing repository. \
For exact matches, enclose your query in double quotes. \
This first iteration of the search tool truncates queries to {MAX_QUERY_CHARS} characters, \
so as not to overwhelm the server this is currently running on, as we work on scaling it up."""

theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
)

monospace_css = """
#q-input textarea {
    font-family: monospace, 'Consolas', Courier, monospace;
}
"""
css = monospace_css + ".gradio-container {color: black}"


if __name__ == "__main__":
    demo = gr.Blocks(theme=theme, css=css)

    with demo:
        with gr.Row():
            gr.Markdown(value=description)
        with gr.Row():
            query = gr.Textbox(
                lines=5,
                placeholder="Type your query here...",
                label="Query",
                elem_id="q-input",
            )
        with gr.Row():
            k = gr.Slider(1, 100, value=10, step=1, label="Max Results")
        with gr.Row():
            submit_btn = gr.Button("Submit")
        with gr.Row():
            results = gr.HTML(label="Results", value="")

        def submit(query, k):
            """Event handler: run the search and update the results panel."""
            query = (query or "").strip()
            if not query:
                # Single output component: return one value, not a tuple.
                return {results: ""}
            return {results: search(query, int(k))}

        query.submit(fn=submit, inputs=[query, k], outputs=[results])
        submit_btn.click(submit, inputs=[query, k], outputs=[results])

    # `launch(enable_queue=...)` is deprecated; `.queue()` enables queuing.
    demo.queue().launch(debug=True)