# coding: utf-8
# Copyright (C) 2023, [Breezedeus](https://github.com/breezedeus).
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import sys
import logging
from typing import List

import yaml
import gradio as gr
from PIL import Image
import numpy as np
from datasets import load_dataset
import chromadb
from chromadb import Settings

from coin_clip.utils import resize_img
from coin_clip.chroma_embedding import ChromaEmbeddingFunction
from coin_clip.detect import Detector

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Deployment flavour: 'hf' selects the Hugging Face config and token,
# anything else falls back to the local config file.
env = os.environ.get('COIN_ENV', 'local')

# Always bind hf_token so that code paths reached with an unexpected COIN_ENV
# value see ``None`` instead of raising NameError.
hf_token = None
if env == 'hf':
    config_fp = 'hf_config.yaml'
    hf_token = os.environ.get('HF_TOKEN')
else:
    config_fp = 'local_config.yaml'
logger.info('Use config file: %s', config_fp)

# Use a context manager so the config file handle is closed deterministically.
with open(config_fp) as _config_file:
    total_config = yaml.safe_load(_config_file)

DETECTOR = Detector(
    model_name=total_config['detector']['model_name'],
    device=total_config['detector']['device'],
)
# USE_REMOVE_BG = total_config['use_remove_bg']
# Size passed to resize_img() before running the detector (default 300).
RESIZED_TO_BEFORE_DETECT = total_config['detector'].get('resized_to', 300)


def prepare_chromadb():
    """Download the prebuilt Chroma database snapshot unless running locally.

    When ``env`` is ``'local'`` nothing is downloaded. Otherwise the
    ``breezedeus/usa-coins-chromadb`` model repo is fetched into the current
    directory using ``hf_token`` (which may be ``None``).
    """
    if env == 'local':
        return

    # Imported lazily so local runs do not need huggingface_hub.
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_type='model',
        repo_id='breezedeus/usa-coins-chromadb',
        local_dir='./',
        token=hf_token,
    )


def _load_dataset(data_path):
    """Load the ``train`` split of the image dataset.

    Args:
        data_path: Passed to ``load_dataset`` as the dataset path when
            ``env == 'hf'``; otherwise used as ``data_dir`` of an
            ``imagefolder`` dataset.

    Returns:
        The dataset object returned by ``datasets.load_dataset``.
    """
    logger.info('Load dataset from %s', data_path)
    if env == 'hf':
        dataset = load_dataset(data_path, split='train', token=hf_token)
    else:
        dataset = load_dataset("imagefolder", data_dir=data_path, split='train')
    return dataset


def detect(images):
    """Run the detector on each image and crop the first detection.

    Args:
        images: Sequence of PIL images (``.size`` and ``.crop`` are used).

    Returns:
        A ``(outs, box_images)`` pair of equal length. Each entry of ``outs``
        is a dict with ``'position'`` (the detected box, or ``None`` when
        nothing was detected) and ``'from_image_idx'``; ``box_images`` holds
        the corresponding crop of the original image, or ``None``.
    """
    outs = []
    for idx, img in enumerate(images):
        # Detection runs on a resized copy; the crop below uses the original.
        resized = resize_img(img, RESIZED_TO_BEFORE_DETECT)
        out = DETECTOR.detect(np.array(resized))
        if not out:
            out = {'position': None, 'scores': 0.0}
        else:
            # Only the first detection is kept.
            out = out[0]
            out.pop('label')
            out['position'] = out.pop('box')
        out['from_image_idx'] = idx
        outs.append(out)

    box_images = []
    for out, img in zip(outs, images):
        if out['position'] is None:
            box_images.append(None)
        else:
            # Convert the box from ratio values to absolute pixel positions.
            w, h = img.size
            box = out['position']
            box = (int(box[0] * w), int(box[1] * h), int(box[2] * w), int(box[3] * h))
            box_images.append(img.crop(box))

    return outs, box_images


def load_chroma_db(db_dir, collection_name, model_name, device='cpu'):
    """Open an existing Chroma collection backed by a persistent client.

    Args:
        db_dir: Directory of the persistent Chroma database.
        collection_name: Name of the collection to fetch.
        model_name: Model name handed to ``ChromaEmbeddingFunction``.
        device: Device handed to ``ChromaEmbeddingFunction``.

    Returns:
        The collection returned by ``client.get_collection``.
    """
    logger.info('Load chroma db from %s', db_dir)
    client = chromadb.PersistentClient(
        path=db_dir, settings=Settings(anonymized_telemetry=False)
    )
    embedding_function = ChromaEmbeddingFunction(model_name, device)
    collection = client.get_collection(
        name=collection_name,
        embedding_function=embedding_function,
    )
    return collection


def retrieve(query_image: Image.Image, collection, top_k=20) -> List[Image.Image]:
    """Query ``collection`` with an image and return the matching dataset images.

    Args:
        query_image: Image used as the query (converted to a numpy array).
        collection: Chroma collection to query.
        top_k: Number of results requested.

    Returns:
        The ``'image'`` field of the ``ds_dict`` entry for each retrieved id,
        in the order returned by the query.
    """
    query_image = np.array(query_image)
    retrieved = collection.query(
        query_images=[query_image],
        include=['metadatas', 'distances'],
        n_results=top_k,
    )
    logger.info('retrieved ids: %s', retrieved['ids'][0])
    logger.info('retrieved distances: %s', retrieved['distances'][0])
    # ``item_id`` rather than ``id`` to avoid shadowing the builtin.
    return [ds_dict[item_id]['image'] for item_id in retrieved['ids'][0]]


# Module-level state used by retrieve() and search().
dataset = _load_dataset(**total_config['dataset'])
ds_dict = {_d['id']: _d for _d in dataset}

prepare_chromadb()
cc_collection = load_chroma_db(**total_config['coin_clip_db'])
clip_collection = load_chroma_db(**total_config['clip_db'])


def search(image_file: Image.Image):
    """Detect a coin in the uploaded image and retrieve similar images.

    Args:
        image_file: Uploaded image; converted to RGB before detection.

    Returns:
        A list of four ``gr.update`` objects. When nothing is detected only
        the second one is visible; otherwise the first carries the cropped
        box image and the third and fourth carry the results from
        ``cc_collection`` and ``clip_collection`` respectively.
        NOTE(review): the components these updates are bound to are wired up
        outside this function; confirm the order there.
    """
    images = [image_file.convert('RGB')]
    detected_outs, box_images = detect(images)
    box_images = [img for img in box_images if img is not None]
    if not box_images:
        return [
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
        ]

    box_image = box_images[0]
    cc_results = retrieve(box_image, cc_collection, top_k=30)
    clip_results = retrieve(box_image, clip_collection, top_k=30)
    return [
        gr.update(value=box_image, visible=True),
        gr.update(visible=False),
        gr.update(value=cc_results, visible=True),
        gr.update(value=clip_results, visible=True),
    ]


def main():
    title = 'USA Coin Retrieval by'
    # desc = (
    # '
Coin-CLIP: ' # 'Model, ' # 'Github; ' # 'Author: Breezedeus , ' # 'Github
' # ) desc = """