Update helper/utils.py
helper/utils.py  CHANGED  (+59 -0)
@@ -1,6 +1,8 @@
 import os
 from datetime import datetime
+import json
 from typing import Any, Dict, List, Tuple, Union
+import requests
 
 import numpy as np
 import pandas as pd
@@ -205,6 +207,63 @@ def call_llama(prompt: str) -> str:
     return response.choices[0].message.content
 
 
+def call_llama2(prompt: str, max_new_tokens: int = 50, temperature: float = 0.9) -> str:
+    """
+    Calls the Llama API to generate text from a given prompt, controlling output length and randomness.
+
+    Args:
+        prompt (str): The prompt text to send to the Llama model for text generation.
+        max_new_tokens (int, optional): The maximum number of tokens the model should generate. Defaults to 50.
+        temperature (float, optional): Controls the randomness of the output; lower values make the model
+            more deterministic, higher values more random. Defaults to 0.9.
+
+    Returns:
+        str: The generated text response from the Llama model.
+
+    Raises:
+        Exception: If the API call fails with a non-200 status code.
+    """
+    # API endpoint for the Llama model
+    api_url = "https://v6rkdcyir7.execute-api.us-east-1.amazonaws.com/beta"
+
+    # Request body: the prompt is wrapped in Llama-2's [INST] instruction tags
+    json_body = {
+        "body": {
+            "inputs": f"<s>[INST] {prompt} [/INST]",
+            "parameters": {
+                "max_new_tokens": max_new_tokens,
+                "top_p": 0.9,  # Fixed nucleus-sampling cutoff: sample only from tokens in the top 90% cumulative probability
+                "temperature": temperature
+            }
+        }
+    }
+
+    # Headers to indicate that the payload is JSON
+    headers = {"Content-Type": "application/json"}
+
+    # Perform the POST request to the Llama API
+    response = requests.post(api_url, headers=headers, json=json_body)
+
+    # Check the status code before touching the body; a failed call has nothing to parse
+    if response.status_code != 200:
+        raise Exception(f"Error calling Llama API: {response.status_code}")
+
+    # Parse the JSON response; the 'body' field is itself a JSON-encoded string
+    response_body = response.json()['body']
+
+    # Convert the string response to a JSON object (a list of generations)
+    body_list = json.loads(response_body)
+
+    # Extract the 'generated_text' from the first item in the list
+    generated_text = body_list[0]['generated_text']
+
+    # The model echoes the instruction, so keep only the answer after the [/INST] tag
+    answer = generated_text.split("[/INST]")[-1].strip()
+
+    # Return the text generated by the model
+    return answer
+
+
267 |
def quantize_to_kbit(arr: Union[np.ndarray, Any], k: int = 16) -> np.ndarray:
|
268 |
"""Converts an array to a k-bit representation by normalizing and scaling its values.
|
269 |
|