cassanof committed on
Commit
35a3912
·
verified ·
1 Parent(s): bcbf898

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +53 -1
README.md CHANGED
@@ -85,4 +85,56 @@ def gen(old, new, max_new_tokens=200, temperature=0.45, top_p=0.90):
85
  return [tokenizer.decode(out[len(toks[0]):], skip_special_tokens=True) for out in outs]
86
  ```
87
 
88
- use the "gen" function with the old and new code
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  return [tokenizer.decode(out[len(toks[0]):], skip_special_tokens=True) for out in outs]
86
  ```
87
 
88
+ Use the `gen` function with the old and new code.
89
+
90
+ # Example:
91
+ ```py
92
+ - import datasets
93
+ - from pathlib import Path
94
+ from code_editing.models import CodeLlamaEditModel, LlamaChatModel, EditModel, EditCommand, ChatAdaptorEditModel, OctoCoderChatModel, codellama_edit_prompt_diff, apply_rel_diff_trim, OpenAIChatModel, StarCoderCommitEditModel
95
+ from code_editing.humanevalpack import batch_prompts_from_example
96
+ from code_editing.utils import gunzip_json_write
97
+ from typing import List, Callable
98
+ from tqdm import tqdm
99
+
100
+
101
+ # NOTE: this is the factory for each model type. to add a new model type, add a new case here
102
+ # and implement it in models.py. Also, add a new case in the argument parser below.
103
+ - def model_factory(model_type: str, quantize=False, num_gpus=1) -> Callable[[str], EditModel]:
104
+ + def model_factory(
105
+ + model_type: str,
106
+ + quantize=False,
107
+ + num_gpus=1,
108
+ + system_supported=True,
109
+ + ) -> Callable[[str], EditModel]:
110
+ if model_type == "codellama" or model_type == "deepseek":
111
+ return CodeLlamaEditModel
112
+ elif model_type == "starcoder":
113
+ return StarCoderCommitEditModel
114
+ elif model_type == "codellama-diff":
115
+ return (lambda path: CodeLlamaEditModel(path, prompt_format=codellama_edit_prompt_diff, post_process=apply_rel_diff_trim))
116
+ elif model_type == "openai":
117
+ return (lambda path: ChatAdaptorEditModel(OpenAIChatModel(path)))
118
+ elif model_type == "codellama-chat":
119
+ - return (lambda path: ChatAdaptorEditModel(LlamaChatModel(path, quantization=quantize, num_gpus=num_gpus)))
120
+ + return (lambda path: ChatAdaptorEditModel(LlamaChatModel(path, quantization=quantize, num_gpus=num_gpus, system_supported=system_supported)))
121
+ elif model_type == "octocoder":
122
+ return (lambda path: ChatAdaptorEditModel(OctoCoderChatModel(path, quantization=quantize, num_gpus=num_gpus)))
123
+ else:
124
+ raise ValueError(f"Unknown model type: {model_type}")
125
+
126
+ def complete_problem(example: EditCommand, model: EditModel, batch_size: int, completion_limit: int, **kwargs) -> List[str]:
127
+ batches = batch_prompts_from_example(example, batch_size, completion_limit)
128
+
129
+ completions = []
130
+ for batch in batches:
131
+ resps = model.generate(batch, **kwargs)
132
+ for resp in resps:
133
+ completions.append(resp["content"])
134
+
135
+ return completions
136
+ ```
137
+ Output produced by the model:
138
+ ```
139
+ Add system_supported argument to model_factory
140
+ ```