ARCQUB committed on
Commit
bc6f92e
·
verified ·
1 Parent(s): 4150544

Update models/gpt_4.1.py

Browse files
Files changed (1) hide show
  1. models/gpt_4.1.py +116 -0
models/gpt_4.1.py CHANGED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # gpt4o_pix2struct_ocr.py
2
+
3
+ import os
4
+ import json
5
+ import base64
6
+ from PIL import Image
7
+ import openai
8
+ import torch
9
+
10
# Chat model name passed to the OpenAI completions API by run_model().
model = "gpt-4.1"

# Pix2Struct model/processor, lazily initialized on first use by
# extract_all_text_pix2struct() to keep module import cheap.
pix2struct_model = None
processor = None
14
+
15
+
16
def load_prompt(prompt_file="prompts/prompt.txt"):
    """Return the prompt text for the model.

    The prompt is read from the PROMPT_TEXT environment variable (a Space
    secret), not from disk; *prompt_file* is unused and kept only for
    backward compatibility with existing callers.

    Returns:
        str: the prompt, or a warning placeholder when PROMPT_TEXT is unset.
    """
    return os.getenv("PROMPT_TEXT", "⚠️ PROMPT_TEXT not found in secrets.")
20
+
21
+
22
def try_extract_json(text):
    """Parse *text* as JSON, tolerating surrounding non-JSON noise.

    First attempts to parse the whole string. On failure, locates the first
    '{' and decodes exactly one JSON object starting there using
    json.JSONDecoder.raw_decode, which — unlike naive brace counting —
    correctly handles braces inside JSON string literals (e.g. '{"a": "}"}').

    Args:
        text: candidate string, typically an LLM response.

    Returns:
        The decoded Python object, or None when no valid JSON can be found.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    start = text.find('{')
    if start == -1:
        return None

    try:
        obj, _end = json.JSONDecoder().raw_decode(text, start)
        return obj
    except json.JSONDecodeError:
        return None
43
+
44
+
45
def encode_image_base64(image: Image.Image):
    """Serialize *image* as JPEG and return it base64-encoded (UTF-8 str)."""
    from io import BytesIO

    jpeg_buffer = BytesIO()
    image.save(jpeg_buffer, format="JPEG")
    raw_bytes = jpeg_buffer.getvalue()
    return base64.b64encode(raw_bytes).decode("utf-8")
50
+
51
+
52
def extract_all_text_pix2struct(image: Image.Image):
    """Generate a text description/transcription of *image* with Pix2Struct.

    The model and processor are loaded lazily on first call and cached in
    module-level globals so subsequent calls reuse them.
    """
    global pix2struct_model, processor

    # First call: pull the pretrained weights and pick the best device.
    if pix2struct_model is None or processor is None:
        from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration

        device = "cuda" if torch.cuda.is_available() else "cpu"
        processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
        pix2struct_model = Pix2StructForConditionalGeneration.from_pretrained(
            "google/pix2struct-textcaps-base"
        ).to(device)

    encoded = processor(images=image, return_tensors="pt").to(pix2struct_model.device)
    generated_ids = pix2struct_model.generate(**encoded, max_new_tokens=512)
    decoded_text = processor.decode(generated_ids[0], skip_special_tokens=True)
    return decoded_text.strip()
67
+
68
+
69
def assign_event_gateway_names_from_ocr(image: "Image.Image", json_data, ocr_text):
    """Fill in placeholder names for unnamed events and gateways.

    Args:
        image: source diagram image. Currently unused; kept so callers can
            later pass pixel data for real OCR-based name matching.
        json_data: parsed model output (dict with optional "events" and
            "gateways" lists), or None when upstream JSON parsing failed.
        ocr_text: OCR text for the image; when empty, json_data is returned
            untouched.

    Returns:
        json_data, mutated in place where names were missing; None (or any
        falsy json_data) passes through unchanged.
    """
    # Guard against a failed upstream parse: run_model may hand us None.
    if not ocr_text or not json_data:
        return json_data

    def ensure_name(obj):
        # NOTE(review): despite the OCR plumbing, missing names are only
        # replaced with a fixed placeholder — ocr_text is not matched yet.
        if not obj.get("name") or obj["name"].strip() == "":
            obj["name"] = "(label unknown)"

    for event in json_data.get("events", []):
        ensure_name(event)

    for gateway in json_data.get("gateways", []):
        ensure_name(gateway)

    return json_data
84
+
85
+
86
def run_model(image: "Image.Image", api_key: str = None):
    """Send *image* plus the configured prompt to the OpenAI chat model.

    Args:
        image: diagram image to analyze; sent as a base64 JPEG data URL.
        api_key: OpenAI API key; falls back to the OPENAI_API_KEY
            environment variable when not supplied.

    Returns:
        dict with keys "json" (parsed model output, or None when the
        response contained no parseable JSON or the key was missing) and
        "raw" (the raw completion text, or an error notice).
    """
    prompt_text = load_prompt()
    encoded_image = encode_image_base64(image)

    api_key = api_key or os.getenv("OPENAI_API_KEY")
    if not api_key:
        return {"json": None, "raw": "⚠️ API key is missing. Please set it as a secret in your Space or upload it as a file."}

    client = openai.OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_text},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
                ]
            }
        ],
        max_tokens=5000
    )

    output_text = response.choices[0].message.content.strip()
    parsed_json = try_extract_json(output_text)

    # Run Pix2Struct OCR enrichment only when parsing produced a structure;
    # previously a None result crashed the enrichment step downstream.
    if parsed_json is not None:
        full_ocr_text = extract_all_text_pix2struct(image)
        parsed_json = assign_event_gateway_names_from_ocr(image, parsed_json, full_ocr_text)

    return {"json": parsed_json, "raw": output_text}