# Spaces: Running
import base64 | |
import os | |
import spaces | |
def got_ocr(model, tokenizer, image_path, got_mode="format texts OCR", fine_grained_mode="", ocr_color="", ocr_box=""):
    """Run OCR on an image in the requested mode.

    Args:
        model: Object exposing ``chat(...)`` and ``chat_crop(...)``; both are
            called with ``tokenizer`` and ``image_path`` as the first two
            positional arguments.
        tokenizer: Passed through to the model unchanged.
        image_path: Path of the image to recognize. For "format" modes the
            rendered HTML is saved next to it as ``<stem>_result.html``.
        got_mode: One of "plain texts OCR", "format texts OCR",
            "plain multi-crop OCR", "format multi-crop OCR",
            "plain fine-grained OCR", "format fine-grained OCR".
        fine_grained_mode: Accepted for interface compatibility; not used
            by this function.
        ocr_color: Forwarded as ``ocr_color`` in the fine-grained modes.
        ocr_box: Forwarded as ``ocr_box`` in the fine-grained modes.

    Returns:
        A ``(result, encoded_html)`` tuple. ``encoded_html`` is the
        base64-encoded content of the rendered HTML file for "format" modes
        when that file exists, otherwise ``None``. If anything raises
        (including an unsupported ``got_mode``), the tuple is
        ``("错误: <message>", None)`` instead of propagating the exception.
    """
    format_modes = ("format texts OCR", "format multi-crop OCR", "format fine-grained OCR")

    # Run the OCR
    try:
        # Plain modes produce text only, so return immediately.
        if got_mode == "plain texts OCR":
            return model.chat(tokenizer, image_path, ocr_type="ocr"), None
        if got_mode == "plain multi-crop OCR":
            return model.chat_crop(tokenizer, image_path, ocr_type="ocr"), None
        if got_mode == "plain fine-grained OCR":
            return model.chat(tokenizer, image_path, ocr_type="ocr", ocr_box=ocr_box, ocr_color=ocr_color), None

        # Reject unknown modes explicitly instead of failing later on an
        # unbound local with a misleading error message.
        if got_mode not in format_modes:
            raise ValueError(f"Unsupported got_mode: {got_mode!r}")

        # All format modes render to the same location beside the image.
        result_path = f"{os.path.splitext(image_path)[0]}_result.html"
        if got_mode == "format texts OCR":
            res = model.chat(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
        elif got_mode == "format multi-crop OCR":
            res = model.chat_crop(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
        else:  # "format fine-grained OCR"
            res = model.chat(tokenizer, image_path, ocr_type="format", ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)

        # Post-process the formatted result
        if not os.path.exists(result_path):
            return res, None
        # NOTE(review): the file is read with the platform default encoding,
        # as before; confirm the renderer writes it with the same encoding.
        with open(result_path, "r") as f:
            html_content = f.read()
        encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
        return res, encoded_html
    except Exception as e:
        # Best-effort boundary: report the failure as the result text.
        return f"错误: {str(e)}", None
# Usage example: load the GOT-OCR model and run one formatted OCR pass.
if __name__ == "__main__":
    import torch
    from transformers import AutoModel, AutoTokenizer

    # Initialize the model and tokenizer
    model_name = "stepfun-ai/GOT-OCR2_0"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        # Use the detected device rather than a hard-coded "cuda", so the
        # CPU fallback computed above is actually honored at load time.
        device_map=device,
        use_safetensors=True,
    )
    model = model.eval().to(device)
    model.config.pad_token_id = tokenizer.eos_token_id

    image_path = "path/to/your/image.png"
    result, html = got_ocr(model, tokenizer, image_path, got_mode="format texts OCR")
    print("OCR结果:", result)
    if html:
        print("HTML结果可用")