Spaces:
Sleeping
Sleeping
File size: 6,428 Bytes
e1aa577 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import yaml
from easydict import EasyDict as edict
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOpenAI
from pathlib import Path
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_community.chat_models import AzureChatOpenAI
from langchain.chains import LLMChain
import logging
# Provider settings (read by get_llm), loaded once at import time.
# The context manager closes the file handle instead of leaving it to the garbage collector.
with open('config/llm_env.yml', 'r') as _llm_env_file:
    LLM_ENV = yaml.safe_load(_llm_env_file)
class Color:
    """Escape sequences used to color text written to a terminal."""

    RED = "\033[91m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    BLUE = "\033[94m"
    END = "\033[0m"  # Reset to default color
def get_llm(config: dict):
    """
    Build the LLM client selected by the configuration.

    ``config['type']`` picks the provider: 'OpenAI', 'Azure', 'Google' or
    'HuggingFacePipeline'. ``config['name']`` is the model (or Azure deployment) name.
    Optional ``temperature`` (default 0) and ``model_kwargs`` (default {}) are forwarded
    where the provider branch uses them. Credentials and endpoints fall back to the
    values in LLM_ENV when the config does not provide them.

    :param config: dictionary with the configuration
    :return: The llm model
    :raises NotImplementedError: if ``config['type']`` is not a supported provider
    """
    temperature = config.get('temperature', 0)
    model_kwargs = config.get('model_kwargs', {})
    llm_type = config['type']

    if llm_type == 'OpenAI':
        openai_env = LLM_ENV['openai']
        has_organization = openai_env['OPENAI_ORGANIZATION'] != ''
        openai_args = {
            'temperature': temperature,
            'model_name': config['name'],
            'openai_api_key': config.get('openai_api_key', openai_env['OPENAI_API_KEY']),
        }
        if has_organization:
            # With an organization set, the base URL falls back to the public OpenAI endpoint.
            openai_args['openai_api_base'] = config.get('openai_api_base', 'https://api.openai.com/v1')
            openai_args['openai_organization'] = config.get('openai_organization',
                                                            openai_env['OPENAI_ORGANIZATION'])
        else:
            openai_args['openai_api_base'] = config.get('openai_api_base', openai_env['OPENAI_API_BASE'])
        return ChatOpenAI(**openai_args, model_kwargs=model_kwargs)

    if llm_type == 'Azure':
        deployment = config['name']
        azure_env = LLM_ENV['azure']
        return AzureChatOpenAI(
            temperature=temperature,
            azure_deployment=deployment,
            openai_api_key=config.get('openai_api_key', azure_env['AZURE_OPENAI_API_KEY']),
            azure_endpoint=config.get('azure_endpoint', azure_env['AZURE_OPENAI_ENDPOINT']),
            openai_api_version=config.get('openai_api_version', azure_env['OPENAI_API_VERSION']),
        )

    if llm_type == 'Google':
        # Imported lazily so the Google package is only needed when this provider is used.
        from langchain_google_genai import ChatGoogleGenerativeAI
        return ChatGoogleGenerativeAI(
            temperature=temperature,
            model=config['name'],
            google_api_key=LLM_ENV['google']['GOOGLE_API_KEY'],
            model_kwargs=model_kwargs,
        )

    if llm_type == 'HuggingFacePipeline':
        gpu_device = config.get('gpu_device', -1)
        placement_map = config.get('device_map', None)
        return HuggingFacePipeline.from_model_id(
            model_id=config['name'],
            task="text-generation",
            pipeline_kwargs={"max_new_tokens": config['max_new_tokens']},
            device=gpu_device,
            device_map=placement_map,
        )

    raise NotImplementedError("LLM not implemented")
def load_yaml(yaml_path: str, as_edict: bool = True) -> dict:
    """
    Reads the yaml file and enriches it with more values.

    When the document has a ``meta_prompts`` mapping containing a ``folder`` entry,
    that entry is converted to a ``pathlib.Path``. An empty yaml file is treated as
    an empty configuration instead of raising.

    :param yaml_path: The path to the yaml file
    :param as_edict: If True, returns an EasyDict configuration; otherwise a plain dict
    :return: The configuration (an EasyDict when ``as_edict`` is True)
    """
    with open(yaml_path, 'r') as file:
        yaml_data = yaml.safe_load(file)
    # safe_load yields None for an empty document; fall back to an empty mapping.
    if yaml_data is None:
        yaml_data = {}
    meta_prompts = yaml_data.get('meta_prompts')
    # A bare `meta_prompts:` key loads as None, so only a real mapping is touched.
    if isinstance(meta_prompts, dict) and 'folder' in meta_prompts:
        meta_prompts['folder'] = Path(meta_prompts['folder'])
    if as_edict:
        yaml_data = edict(yaml_data)
    return yaml_data
def load_prompt(prompt_path: str) -> PromptTemplate:
    """
    Load a prompt file and wrap its text in a PromptTemplate.

    :param prompt_path: The path to the prompt file
    :return: A PromptTemplate built from the file contents with trailing whitespace removed
    """
    with open(prompt_path, 'r') as prompt_file:
        raw_text = prompt_file.read()
    template_text = raw_text.rstrip()
    return PromptTemplate.from_template(template_text)
def validate_generation_config(base_config, generation_config):
    """
    Check that a generation config is compatible with the base config.

    :param base_config: The base configuration; its ``dataset.label_schema`` is the reference
    :param generation_config: The generation configuration to validate
    :raises Exception: if ``annotator`` is missing from the generation config, or if its
        ``dataset.label_schema`` is missing or differs from the base config's
    """
    has_annotator = "annotator" in generation_config
    if not has_annotator:
        raise Exception("Generation config must contain an empty annotator.")

    schema_error = "Generation label schema must match the basic config."
    generation_dataset = generation_config.dataset
    if "label_schema" not in generation_dataset:
        raise Exception(schema_error)
    if base_config.dataset.label_schema != generation_dataset.label_schema:
        raise Exception(schema_error)
def modify_input_for_ranker(config, task_description, initial_prompt):
    """
    Use the LLM to rewrite the task description and the initial prompt for ranking.

    The two rewrite prompts are loaded from the paths listed under ``ranker`` in
    'prompts/modifiers/modifiers.yml' (``task_desc_mod`` and ``prompt_mod``).

    :param config: The configuration; ``config.llm`` selects the LLM and
        ``config.dataset.label_schema`` is passed to the prompt rewrite
    :param task_description: The original task description
    :param initial_prompt: The original initial prompt
    :return: A tuple ``(modified_prompt, modified_task_description)``
    """
    # Context manager so the modifiers file handle is closed as soon as it is parsed.
    with open('prompts/modifiers/modifiers.yml', 'r') as modifiers_file:
        modifiers_config = yaml.safe_load(modifiers_file)
    task_desc_setup = load_prompt(modifiers_config['ranker']['task_desc_mod'])
    init_prompt_setup = load_prompt(modifiers_config['ranker']['prompt_mod'])

    # One LLM instance is shared by both rewrite chains.
    llm = get_llm(config.llm)

    task_llm_chain = LLMChain(llm=llm, prompt=task_desc_setup)
    task_result = task_llm_chain(
        {"task_description": task_description})
    mod_task_desc = task_result['text']
    logging.info(f"Task description modified for ranking to: \n{mod_task_desc}")

    prompt_llm_chain = LLMChain(llm=llm, prompt=init_prompt_setup)
    prompt_result = prompt_llm_chain({"prompt": initial_prompt, 'label_schema': config.dataset.label_schema})
    mod_prompt = prompt_result['text']
    logging.info(f"Initial prompt modified for ranking to: \n{mod_prompt}")

    return mod_prompt, mod_task_desc
def override_config(override_config_file, config_file='config/config_default.yml'):
    """
    Override the default configuration file with the override configuration file.

    Nested mappings are merged key by key; any other value in the override file
    replaces the default value.

    :param override_config_file: The override configuration file
    :param config_file: The default configuration file
    :return: An EasyDict with the merged configuration
    """

    def override_dict(config_dict, override_config_dict):
        """Merge override_config_dict into config_dict in place and return config_dict."""
        for key, value in override_config_dict.items():
            existing = config_dict.get(key)
            # Recurse only when both sides are mappings. If the default is missing or is
            # a scalar/None, the override replaces it (recursing there would fail).
            if isinstance(value, dict) and isinstance(existing, dict):
                override_dict(existing, value)
            else:
                config_dict[key] = value
        return config_dict

    # Plain dicts are loaded so the merge works on regular mappings; the result is
    # converted to an EasyDict once at the end.
    default_config_dict = load_yaml(config_file, as_edict=False)
    override_config_dict = load_yaml(override_config_file, as_edict=False)
    config_dict = override_dict(default_config_dict, override_config_dict)
    return edict(config_dict)
|