rishiraj commited on
Commit
74282da
1 Parent(s): 7b86364

Create oai/oai_extractor.py

Browse files
Files changed (1) hide show
  1. oai/oai_extractor.py +78 -0
oai/oai_extractor.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union, Optional
2
+ from indexify_extractor_sdk import Content, Extractor, Feature
3
+ from pydantic import BaseModel, Field
4
+ import os
5
+ import base64
6
+ from openai import OpenAI
7
+ from pdf2image import convert_from_path
8
+ import tempfile
9
+ import mimetypes
10
+
11
+ class OAIExtractorConfig(BaseModel):
12
+ model_name: Optional[str] = Field(default='gpt-3.5-turbo')
13
+ key: Optional[str] = Field(default=None)
14
+ prompt: str = Field(default='You are a helpful assistant.')
15
+ query: Optional[str] = Field(default=None)
16
+
17
+ class OAIExtractor(Extractor):
18
+ name = "tensorlake/openai"
19
+ description = "An extractor that let's you use LLMs from OpenAI."
20
+ system_dependencies = []
21
+ input_mime_types = ["text/plain", "application/pdf", "image/jpeg", "image/png"]
22
+
23
+ def __init__(self):
24
+ super(OAIExtractor, self).__init__()
25
+
26
+ def extract(self, content: Content, params: OAIExtractorConfig) -> List[Union[Feature, Content]]:
27
+ contents = []
28
+ model_name = params.model_name
29
+ key = params.key
30
+ prompt = params.prompt
31
+ query = params.query
32
+
33
+ if content.content_type in ["application/pdf"]:
34
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
35
+ temp_file.write(content.data)
36
+ file_path = temp_file.name
37
+ images = convert_from_path(file_path)
38
+ image_files = []
39
+ for i in range(len(images)):
40
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_image_file:
41
+ images[i].save(temp_image_file.name, 'JPEG')
42
+ image_files.append(temp_image_file.name)
43
+ elif content.content_type in ["image/jpeg", "image/png"]:
44
+ image_files = []
45
+ suffix = mimetypes.guess_extension(content.content_type)
46
+ with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_image_file:
47
+ temp_image_file.write(content.data)
48
+ file_path = temp_image_file.name
49
+ image_files.append(file_path)
50
+ else:
51
+ text = content.data.decode("utf-8")
52
+ if query is None:
53
+ query = text
54
+ file_path = None
55
+
56
+ def encode_image(image_path):
57
+ with open(image_path, "rb") as image_file:
58
+ return base64.b64encode(image_file.read()).decode('utf-8')
59
+
60
+ if ('OPENAI_API_KEY' not in os.environ) and (key is None):
61
+ response_content = "The OPENAI_API_KEY environment variable is not present."
62
+ else:
63
+ if ('OPENAI_API_KEY' in os.environ) and (key is None):
64
+ client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
65
+ else:
66
+ client = OpenAI(api_key=key)
67
+ if file_path:
68
+ encoded_images = [encode_image(image_path) for image_path in image_files]
69
+ messages_content = [ { "role": "user", "content": [ { "type": "text", "text": prompt, } ] + [ { "type": "image_url", "image_url": { "url": f"data:image/jpeg;base64,{encoded_image}", }, } for encoded_image in encoded_images ], } ]
70
+ else:
71
+ messages_content = [ {"role": "system", "content": prompt}, {"role": "user", "content": query} ]
72
+
73
+ response = client.chat.completions.create( model=model_name, messages=messages_content )
74
+ response_content = response.choices[0].message.content
75
+
76
+ contents.append(Content.from_text(response_content))
77
+
78
+ return contents