"""Streamlit app for zero-shot image classification with CLIP.

The page lets the user upload a picture, edit a comma-separated list of
candidate classes and, on demand, shows the score of every class as a table.

``load_clip`` and ``inference_clip`` offer a manual CLIP code path; the layout
below relies on the ``transformers`` zero-shot pipeline instead.
"""

import streamlit as st
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor, pipeline


#################################
#### FUNCTIONS

def load_clip(model_size='large'):
    """Load a pretrained CLIP processor and model.

    Args:
        model_size: ``'base'`` or ``'large'``; selects the checkpoint to load.

    Returns:
        A ``(processor, model)`` tuple.

    Raises:
        ValueError: If ``model_size`` is not one of the supported sizes.
    """
    checkpoints = {
        'base': 'openai/clip-vit-base-patch32',
        'large': 'openai/clip-vit-large-patch14',
    }
    try:
        checkpoint = checkpoints[model_size]
    except KeyError:
        raise ValueError(
            f"Unsupported model_size {model_size!r}; "
            f"expected one of {sorted(checkpoints)}"
        ) from None

    model = CLIPModel.from_pretrained(checkpoint)
    processor = CLIPProcessor.from_pretrained(checkpoint)
    return processor, model


def inference_clip(options, image, processor, model):
    """Return the entry of ``options`` that best matches ``image``.

    Args:
        options: Sequence of candidate text labels.
        image: The image to classify (a single image is assumed).
        processor: CLIP processor used to build the model inputs.
        model: CLIP model used to score the image against each label.

    Returns:
        The element of ``options`` with the highest probability.
    """
    inputs = processor(text=options, images=image, return_tensors="pt", padding=True)

    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_image is the image-text similarity score; softmax over the
    # label axis turns it into label probabilities. Only one image is
    # passed in, so the first row holds the probability of every option.
    probs = outputs.logits_per_image.softmax(dim=1)[0]
    max_prob_idx = int(torch.argmax(probs))
    return options[max_prob_idx]


def _load_classifier(name):
    """Build the zero-shot image classification pipeline for checkpoint ``name``."""
    return pipeline("zero-shot-image-classification", model=name)


# The script is re-executed on user interaction, so keep a single pipeline
# instance alive when the installed Streamlit offers st.cache_resource;
# otherwise fall back to building it on each run.
if hasattr(st, "cache_resource"):
    _load_classifier = st.cache_resource(_load_classifier)


#################################
#### LAYOUT

model_name = "openai/clip-vit-large-patch14-336"
classifier = _load_classifier(model_name)

#### Loading picture
image = None  # stays None until the user uploads a picture
picture_file = st.file_uploader("Picture :", type=["jpg", "jpeg", "png"])
if picture_file is not None:
    image = Image.open(picture_file)
    st.image(image, caption='Please upload an image of the damage', use_column_width=True)

col_l, col_r = st.columns(2)

with col_l:
    default_options = [
        'black', 'white', 'gray', 'red', 'blue', 'silver', 'brown',
        'green', 'orange', 'beige', 'purple', 'gold', 'yellow',
    ]
    # The text box holds a single string, so the classes are shown and read
    # back as a comma-separated list.
    options_text = st.text_input(
        label="Please enter the classes",
        value=", ".join(default_options),
    )
    # The classifier needs a list of labels, not the raw text.
    options = [label.strip() for label in options_text.split(",") if label.strip()]

# button to launch compute
if st.button("Compute"):
    if image is None:
        st.warning("Please upload a picture before computing.")
    elif not options:
        st.warning("Please enter at least one class.")
    else:
        scores = classifier(image, candidate_labels=options)
        with col_r:
            st.dataframe(scores)