Vipitis committed on
Commit 2d141af
1 Parent(s): 980b6a3

refactor tree_utils

Files changed (2):
  1. app.py +11 -65
  2. tree_utils.py +59 -0
app.py CHANGED
@@ -6,6 +6,9 @@ import numpy as np
  import torch
  from threading import Thread

+ from tree_utils import _parse_functions, _get_docstrings, _grab_before_comments, _line_chr2char
+ PIPE = None
+
  def make_script(shader_code):
      # code copied and fixed(escaping single quotes to double quotes!!!) from https://webglfundamentals.org/webgl/webgl-shadertoy.html
      script = ("""
@@ -295,18 +298,6 @@ new_shadertoy_code = """void mainImage( out vec4 fragColor, in vec2 fragCoord )
      fragColor = vec4(col,1.0);
  }"""

- passes_dataset = datasets.load_dataset("Vipitis/Shadertoys")
- single_passes = passes_dataset.filter(lambda x: not x["has_inputs"] and x["num_passes"] == 1) #could also include shaders with no extra functions.
- # single_passes = single_passes.filter(lambda x: x["license"] not in "copyright") #to avoid any "do not display this" license?
- all_single_passes = datasets.concatenate_datasets([single_passes["train"], single_passes["test"]])
- num_samples = len(all_single_passes)
-
- import tree_sitter
- from tree_sitter import Language, Parser
- Language.build_library("./build/my-languages.so", ['tree-sitter-glsl'])
- GLSL_LANGUAGE = Language('./build/my-languages.so', 'glsl')
- parser = Parser()
- parser.set_language(GLSL_LANGUAGE)

  def grab_sample(sample_idx):
      sample_pass = all_single_passes[sample_idx]
@@ -322,19 +313,6 @@ def grab_sample(sample_idx):
      # print(f"updating drop down to:{func_identifiers}")
      return sample_pass, sample_code, sample_title, source_iframe, funcs#, gr.Dropdown.update(choices=func_identifiers) #, sample_title, sample_auhtor

-
- def _parse_functions(in_code):
-     """
-     returns all functions in the code as their actual nodes.
-     includes any comment made directly after the function definition or diretly after #copilot trigger
-     """
-     tree = parser.parse(bytes(in_code, "utf8"))
-     funcs = [n for n in tree.root_node.children if n.type == "function_definition"]
-
-     return funcs
-
- PIPE = None
-
  def _make_pipeline(model_cp = "Vipitis/santacoder-finetuned-Shadertoys-fine"): #bad default model for testing
      # if torch.cuda.is_available():
      #     device = "cuda"
@@ -436,16 +414,6 @@ def alter_return(orig_code, func_idx, temperature, max_new_tokens, top_p, repeti

      return altered_code

- def _line_chr2char(text, line_idx, chr_idx):
-     """
-     returns the character index at the given line and character index.
-     """
-     lines = text.split("\n")
-     char_idx = 0
-     for i in range(line_idx):
-         char_idx += len(lines[i]) + 1
-     char_idx += chr_idx
-     return char_idx

  def _combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty):
      gen_kwargs = {}
@@ -455,34 +423,6 @@ def _combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_pe
      gen_kwargs["repetition_penalty"] = repetition_penalty
      return gen_kwargs

- def _grab_before_comments(func_node):
-     """
-     returns the comments that happen just before a function node
-     """
-     precomment = ""
-     last_comment_line = 0
-     for node in func_node.parent.children: #could you optimize where to iterated from? directon?
-         if node.start_point[0] != last_comment_line + 1:
-             precomment = ""
-         if node.type == "comment":
-             precomment += node.text.decode() + "\n"
-             last_comment_line = node.start_point[0]
-         elif node == func_node:
-             return precomment
-     return precomment
-
- def _get_docstrings(func_node):
-     """
-     returns the docstring of a function node
-     """
-     docstring = ""
-     for node in func_node.child_by_field_name("body").children:
-         if node.type == "comment" or node.type == "{":
-             docstring += node.text.decode() + "\n"
-         else:
-             return docstring
-     return docstring
-
  def alter_body(old_code, func_id, funcs_list: list, prompt, temperature, max_new_tokens, top_p, repetition_penalty, pipeline=PIPE):
      """
      Replaces the body of a function with a generated one.
@@ -581,7 +521,7 @@ def construct_embed(source_url):
  with gr.Blocks() as site:
      top_md = gr.Markdown(intro_text)
      model_cp = gr.Textbox(value="Vipitis/santacoder-finetuned-Shadertoys-fine", label="Model Checkpoint (Enter to load!)", interactive=True)
-     sample_idx = gr.Slider(minimum=0, maximum=num_samples, value=3211, label="pick sample from dataset", step=1.0)
+     sample_idx = gr.Slider(minimum=0, maximum=10513, value=3211, label="pick sample from dataset", step=1.0)
      func_dropdown = gr.Dropdown(value=["0: edit the Code (or load a shader) to update this dropdown"], label="chose a function to modify") #breaks if I add a string in before that? #TODO: use type="index" to get int - always gives None?
      prompt_text = gr.Textbox(value="the title used by the model has generation hint", label="prompt text", info="leave blank to skip", interactive=True)
      with gr.Accordion("Advanced settings", open=False): # from: https://huggingface.co/spaces/bigcode/bigcode-playground/blob/main/app.py
@@ -644,7 +584,7 @@ with gr.Blocks() as site:

      model_cp.submit(fn=_make_pipeline, inputs=[model_cp], outputs=[pipe]) # how can we trigger this on load?
      sample_idx.release(fn=grab_sample, inputs=[sample_idx], outputs=[sample_pass, sample_code, prompt_text, source_embed]) #funcs here?
-     gen_return_button.click(fn=alter_return, inputs=[sample_code, func_dropdown, pipe], outputs=[sample_code])
+     gen_return_button.click(fn=alter_return, inputs=[sample_code, func_dropdown, temperature, max_new_tokens, top_p, repetition_penalty, pipe], outputs=[sample_code])
      gen_func_button.click(fn=alter_body, inputs=[sample_code, func_dropdown, funcs, prompt_text, temperature, max_new_tokens, top_p, repetition_penalty, pipe], outputs=[sample_code, pipe]).then(
          fn=list_dropdown, inputs=[sample_code], outputs=[funcs, func_dropdown]
      )
@@ -652,5 +592,11 @@ with gr.Blocks() as site:
          fn=make_iframe, inputs=[sample_code], outputs=[our_embed])

  if __name__ == "__main__": #works on huggingface?
+     passes_dataset = datasets.load_dataset("Vipitis/Shadertoys")
+     single_passes = passes_dataset.filter(lambda x: not x["has_inputs"] and x["num_passes"] == 1) #could also include shaders with no extra functions.
+     # single_passes = single_passes.filter(lambda x: x["license"] not in "copyright") #to avoid any "do not display this" license?
+     all_single_passes = datasets.concatenate_datasets([single_passes["train"], single_passes["test"]])
+     num_samples = len(all_single_passes)
+
      site.queue()
      site.launch()
tree_utils.py ADDED
@@ -0,0 +1,59 @@
+ import tree_sitter
+ from tree_sitter import Language, Parser
+
+ Language.build_library("./build/my-languages.so", ['tree-sitter-glsl'])
+ GLSL_LANGUAGE = Language('./build/my-languages.so', 'glsl')
+ parser = Parser()
+ parser.set_language(GLSL_LANGUAGE)
+
+
+ def _parse_functions(in_code):
+     """
+     returns all functions in the code as their actual nodes.
+     includes any comment made directly after the function definition or diretly after #copilot trigger
+     """
+     tree = parser.parse(bytes(in_code, "utf8"))
+     funcs = [n for n in tree.root_node.children if n.type == "function_definition"]
+
+     return funcs
+
+
+ def _get_docstrings(func_node):
+     """
+     returns the docstring of a function node
+     """
+     docstring = ""
+     for node in func_node.child_by_field_name("body").children:
+         if node.type == "comment" or node.type == "{":
+             docstring += node.text.decode() + "\n"
+         else:
+             return docstring
+     return docstring
+
+
+ def _grab_before_comments(func_node):
+     """
+     returns the comments that happen just before a function node
+     """
+     precomment = ""
+     last_comment_line = 0
+     for node in func_node.parent.children: #could you optimize where to iterated from? directon?
+         if node.start_point[0] != last_comment_line + 1:
+             precomment = ""
+         if node.type == "comment":
+             precomment += node.text.decode() + "\n"
+             last_comment_line = node.start_point[0]
+         elif node == func_node:
+             return precomment
+     return precomment
+
+ def _line_chr2char(text, line_idx, chr_idx):
+     """
+     returns the character index at the given line and character index.
+     """
+     lines = text.split("\n")
+     char_idx = 0
+     for i in range(line_idx):
+         char_idx += len(lines[i]) + 1
+     char_idx += chr_idx
+     return char_idx
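
For reference, a minimal usage sketch of the extracted helpers (not part of the commit). It assumes the tree-sitter-glsl grammar source is available so the build in tree_utils.py succeeds; the GLSL snippet and variable names are invented for illustration:

from tree_utils import _parse_functions, _get_docstrings, _grab_before_comments, _line_chr2char

# toy GLSL snippet with a comment above the function and one inside its body
shader = "// adds two colors\nvec3 blend(vec3 a, vec3 b){\n    // simple helper\n    return a + b;\n}"

funcs = _parse_functions(shader)          # tree-sitter nodes of type "function_definition"
print(_grab_before_comments(funcs[0]))    # the comment line directly above the function
print(_get_docstrings(funcs[0]))          # the opening "{" plus the leading body comment
print(_line_chr2char(shader, 2, 4))       # flat character offset of line 2, column 4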