File size: 7,539 Bytes
87c3140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import os
import sys
import inspect
import json
from json import JSONDecodeError
import tiktoken
import random 
import google.generativeai as palm

currentdir = os.path.dirname(os.path.abspath(
    inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)

from prompt_catalog import PromptCatalog
from general_utils import num_tokens_from_string

"""
DEPRECATED:
    Safety setting regularly block a response, so set to 4 to disable

    class HarmBlockThreshold(Enum):
        HARM_BLOCK_THRESHOLD_UNSPECIFIED = 0
        BLOCK_LOW_AND_ABOVE = 1
        BLOCK_MEDIUM_AND_ABOVE = 2
        BLOCK_ONLY_HIGH = 3
        BLOCK_NONE = 4
"""

# Per-category safety overrides sent with every PaLM request.  Every category
# is set to BLOCK_NONE because the default thresholds frequently blocked
# harmless OCR/label content (see the DEPRECATED note above).
SAFETY_SETTINGS = [
    {
        "category": "HARM_CATEGORY_DEROGATORY",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_TOXICITY",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_VIOLENCE",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_SEXUAL",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_MEDICAL",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS",
        "threshold": "BLOCK_NONE",
    },
]

# Generation settings for the first PaLM call: temperature 0 for
# deterministic output, single candidate.
PALM_SETTINGS = {
    'model': 'models/text-bison-001',
    'temperature': 0,
    'candidate_count': 1,
    'top_k': 40,
    'top_p': 0.95,
    'max_output_tokens': 8000,
    'stop_sequences': [],
    'safety_settings': SAFETY_SETTINGS,
}

# Settings for the final JSON-repair retry: identical except for a slightly
# raised temperature (0.05) to nudge the model away from repeating the same
# invalid-JSON response verbatim.
PALM_SETTINGS_REDO = {
    'model': 'models/text-bison-001',
    'temperature': 0.05,
    'candidate_count': 1,
    'top_k': 40,
    'top_p': 0.95,
    'max_output_tokens': 8000,
    'stop_sequences': [],
    'safety_settings': SAFETY_SETTINGS,
}

def OCR_to_dict_PaLM(logger, OCR, prompt_version, VVE):
    """Convert raw OCR text into a structured dict via the PaLM 2 text API.

    Args:
        logger: Logger used for progress/diagnostic messages.
        OCR: Raw OCR text to be structured.
        prompt_version: Prompt template name ('prompt_v2_palm2',
            'prompt_v1_palm2', 'prompt_v1_palm2_noDomainKnowledge'); any
            other value is treated as the name of a custom prompt.
        VVE: Vector-embedding helper; queried for similar domain-knowledge
            examples only for 'prompt_v1_palm2'.

    Returns:
        Tuple (response_valid, nt) where response_valid is the parsed dict,
        {} when the model returned no usable text, or None when all
        JSON-repair retries failed (see check_and_redo_JSON), and nt is the
        prompt's token count.
    """
    try:
        logger.info(f'Length of OCR raw -- {len(OCR)}')
    except Exception:
        # Best-effort logging: fall back to stdout if the logger is unusable.
        print(f'Length of OCR raw -- {len(OCR)}')

    Prompt = PromptCatalog(OCR)
    if prompt_version in ['prompt_v2_palm2']:
        version = 'v2'
        prompt = Prompt.prompt_v2_palm2(OCR)

    elif prompt_version in ['prompt_v1_palm2',]:
        version = 'v1'
        # PaLM needs explicit input:/output: few-shot pairs (at least 2), so
        # pull similar examples from the domain-knowledge vector DB.
        domain_knowledge_example = VVE.query_db(OCR, 4)
        # Result unused here, but retained in case get_similarity() updates
        # internal state — TODO confirm it can be dropped.
        VVE.get_similarity()
        in_list, out_list = create_OCR_analog_for_input(domain_knowledge_example)
        prompt = Prompt.prompt_v1_palm2(in_list, out_list, OCR)

    elif prompt_version in ['prompt_v1_palm2_noDomainKnowledge',]:
        version = 'v1'
        prompt = Prompt.prompt_v1_palm2_noDomainKnowledge(OCR)
    else:
        # Any other value names a custom prompt definition; field count and
        # headers are not needed in this code path.
        version = 'custom'
        prompt, _n_fields, _xlsx_headers = Prompt.prompt_v2_custom(prompt_version, OCR=OCR, is_palm=True)

    nt = num_tokens_from_string(prompt, "cl100k_base")
    logger.info(f'Prompt token length --- {nt}')

    logger.info(f'Waiting for PaLM 2 API call')
    response = palm.generate_text(prompt=prompt, **PALM_SETTINGS)

    if response and response.result:
        if isinstance(response.result, (str, bytes)):
            response_valid = check_and_redo_JSON(response, logger, version)
        else:
            response_valid = {}
    else:
        response_valid = {}

    # getattr guards the case where the call was blocked and response is
    # None/empty — the previous version crashed here on response.result.
    logger.info(f'Candidate JSON\n{getattr(response, "result", None)}')
    return response_valid, nt

def check_and_redo_JSON(response, logger, version):
    """Parse the model's text as JSON, retrying with repair prompts on failure.

    Strategy, in order:
      1. Parse ``response.result`` directly.
      2. Strip a ```` ```json ```` code fence and parse again.
      3. Ask PaLM to repair the JSON (deterministic settings), parse.
      4. Ask again with temperature 0.05 (PALM_SETTINGS_REDO), parse.

    Args:
        response: PaLM response object; only its ``result`` attribute is read.
        logger: Logger for progress messages.
        version: Prompt family ('v1', 'v2', or 'custom') selecting the redo
            prompt template.

    Returns:
        The parsed dict, or None when every attempt failed.

    Raises:
        ValueError: if ``version`` is not one of the known prompt families
            (previously this crashed with UnboundLocalError).
    """
    try:
        response_valid = json.loads(response.result)
        logger.info(f'Response --- First call passed')
        return response_valid
    except JSONDecodeError:

        try:
            # The model often wraps its JSON in a ```json fence; strip it.
            # TypeError is included because result may be bytes, which
            # str-argument strip/replace rejects.
            response_valid = json.loads(response.result.strip('```').replace('json\n', '', 1).replace('json', '', 1))
            logger.info(f'Response --- Manual removal of ```json succeeded')
            return response_valid
        except (JSONDecodeError, TypeError):
            logger.info(f'Response --- First call failed. Redo...')
            Prompt = PromptCatalog()
            if version == 'v1':
                prompt_redo = Prompt.prompt_palm_redo_v1(response.result)
            elif version == 'v2':
                prompt_redo = Prompt.prompt_palm_redo_v2(response.result)
            elif version == 'custom':
                prompt_redo = Prompt.prompt_v2_custom_redo(response.result, is_palm=True)
            else:
                raise ValueError(f'Unknown prompt version for redo: {version}')

            try:
                response = palm.generate_text(prompt=prompt_redo, **PALM_SETTINGS)
                # TypeError covers result=None from a blocked/empty response,
                # which json.loads would otherwise raise uncaught.
                response_valid = json.loads(response.result)
                logger.info(f'Response --- Second call passed')
                return response_valid
            except (JSONDecodeError, TypeError):
                logger.info(f'Response --- Second call failed. Final redo. Temperature changed to 0.05')
                try:
                    response = palm.generate_text(prompt=prompt_redo, **PALM_SETTINGS_REDO)
                    response_valid = json.loads(response.result)
                    logger.info(f'Response --- Third call passed')
                    return response_valid
                except (JSONDecodeError, TypeError):
                    return None
            

def create_OCR_analog_for_input(domain_knowledge_example):
    """Build (input, output) few-shot pairs from domain-knowledge records.

    For each record dict, the output is its JSON serialization and the input
    is the record's values shuffled into one space-separated string,
    mimicking unordered OCR text.

    Args:
        domain_knowledge_example: List of dicts, one per example record.

    Returns:
        Tuple (in_list, out_list) of equal-length string lists.
    """
    in_list = []
    out_list = []
    for record in domain_knowledge_example:
        # Expected model output: the record serialized as a JSON string.
        out_list.append(json.dumps(record))

        # Join the values with '||' and split on the same marker so the
        # shuffle boundaries match the original value-delimiting scheme.
        joined = '||'.join(str(value) for value in record.values())
        pieces = joined.split('||')
        random.shuffle(pieces)

        # Simulated OCR input: the same values in randomized order.
        in_list.append(' '.join(pieces))
    return in_list, out_list


def strip_problematic_chars(s):
    """Return *s* with every non-printable character removed."""
    return ''.join(filter(str.isprintable, s))