Update app.py
Browse files
app.py
CHANGED
@@ -37,10 +37,27 @@ def validate_inputs(gen_prompt, max_tokens, temperature, seed):
|
|
37 |
|
38 |
def generate_code(gen_prompt, max_tokens, temperature=0.6, seed=42):
|
39 |
validate_inputs(gen_prompt, max_tokens, temperature, seed)
|
40 |
-
|
41 |
-
#
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
|
46 |
def save_to_text_file(output_text):
|
|
|
37 |
|
38 |
def generate_code(gen_prompt, max_tokens, temperature=0.6, seed=42):
|
39 |
validate_inputs(gen_prompt, max_tokens, temperature, seed)
|
40 |
+
|
41 |
+
# Encode the input prompt
|
42 |
+
input_ids = tokenizer.encode(gen_prompt, return_tensors="pt")
|
43 |
+
|
44 |
+
# Set seed for reproducibility
|
45 |
+
set_seed(seed)
|
46 |
+
|
47 |
+
# Generate code tokens
|
48 |
+
output = model.generate(
|
49 |
+
input_ids,
|
50 |
+
max_length=max_tokens + input_ids.shape[-1],
|
51 |
+
temperature=temperature,
|
52 |
+
pad_token_id=tokenizer.eos_token_id,
|
53 |
+
num_return_sequences=1
|
54 |
+
)
|
55 |
+
|
56 |
+
# Decode the generated tokens into Python code
|
57 |
+
generated_code = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
|
58 |
+
|
59 |
+
return generated_code
|
60 |
+
|
61 |
|
62 |
|
63 |
def save_to_text_file(output_text):
|