File size: 7,688 Bytes
a5316e5
 
d563836
 
544f914
a5316e5
5f4434d
8da738a
7145ecb
544f914
 
 
6176ef8
f37b5da
d563836
c8ce48e
d563836
 
 
 
 
 
6176ef8
 
 
e39af65
6176ef8
 
 
ccf126e
bb9c09b
ccf126e
 
 
 
 
 
 
 
 
 
 
e39af65
ccf126e
 
 
b1a0d53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c82104c
b1a0d53
6f59e3c
aa09a05
8da738a
763f5b7
8da738a
763f5b7
544f914
736419a
721eaca
8da738a
3fcba4f
6176ef8
4e59324
b1a0d53
aa09a05
24390e2
4e59324
 
 
 
 
3218b1a
4e59324
24390e2
4e59324
 
 
 
 
 
 
 
 
 
 
aa09a05
4e59324
 
aa09a05
4e59324
 
 
3218b1a
4e59324
 
 
3218b1a
 
 
 
4e59324
 
 
 
763f5b7
721eaca
621b193
6176ef8
 
5282aca
7145ecb
1e09a50
f30d0ea
5282aca
1e09a50
b8df8bd
53e71ae
142304a
8effe15
53e71ae
b5d9907
8e84211
142304a
d3c40d6
763f5b7
53e71ae
5282aca
1e09a50
5282aca
15ecc3d
 
8effe15
3cb6c3b
b5d9907
8e84211
142304a
d3c40d6
763f5b7
a5316e5
736419a
4e59324
a5316e5
8954378
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import gradio as gr
from transformers import pipeline
import requests
from bs4 import BeautifulSoup
import pandas as pd

# Load the fine-tuned Pl@ntBERT pipelines from the local "models/" directory.
# The classifier is asked for its 5 best habitat labels.
classification_model = pipeline(
    "text-classification",
    model="models/text_classification_model",
    tokenizer="models/text_classification_model",
    top_k=5,
)
# The fill-mask model is asked for 100 candidate tokens, so masking() can skip
# species that are already present in the plot.
mask_model = pipeline(
    "fill-mask",
    model="models/fill_mask_model",
    tokenizer="models/fill_mask_model",
    top_k=100,
)

# EUNIS habitat table, used to map a predicted EUNIS code to a habitat name.
eunis_habitats = pd.read_excel('data/eunis_habitats.xlsx')
    
def return_habitat_image(habitat_label, use_placeholder=True):
    """Return a ``gr.Image`` illustrating the given habitat.

    By default a fixed placeholder picture is returned without any network
    access, because the rights to display FloraVeg pictures have not been
    obtained yet. The FloraVeg scraping path is kept behind
    ``use_placeholder=False`` for when those rights are available.

    Args:
        habitat_label: Habitat code used to build the FloraVeg overview URL.
        use_placeholder: If True (default), return the placeholder image
            immediately. If False, scrape the FloraVeg habitat page for a
            thumbnail and fall back to a "not found" image on any failure.

    Returns:
        A ``gr.Image`` whose value is the selected image URL.
    """
    # While we don't have the rights, always show this image. Previously the
    # page was fetched and its result thrown away, which only added latency and
    # a way for the handler to crash.
    if use_placeholder:
        return gr.Image(value="https://www.commissionoceanindien.org/wp-content/uploads/2018/07/plantnet.jpg")

    image_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    floraveg_url = f"https://floraveg.eu/habitat/overview/{habitat_label}"
    try:
        # Bounded wait (seconds) so a slow site cannot hang the UI handler.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        # Treat network failures like a non-200 response: show the fallback.
        return gr.Image(value=image_url)
    if response.status_code == 200:
        soup = BeautifulSoup(response.text, 'html.parser')
        img_tag = soup.find('img', src=lambda x: x and x.startswith("https://files.ibot.cas.cz/cevs/images/syntaxa/thumbs/"))
        if img_tag:
            image_url = img_tag['src']
    return gr.Image(value=image_url)

def return_species_image(species, use_placeholder=True):
    """Return a ``gr.Image`` illustrating the given species.

    By default a fixed placeholder picture is returned without any network
    access, because the rights to display FloraVeg pictures have not been
    obtained yet. The FloraVeg scraping path is kept behind
    ``use_placeholder=False`` for when those rights are available.

    Args:
        species: Species name; it is capitalized (e.g. "Lemna minor") before
            being placed in the FloraVeg taxon URL.
        use_placeholder: If True (default), return the placeholder image
            immediately. If False, scrape the FloraVeg taxon page for a picture
            and fall back to a "not found" image on any failure.

    Returns:
        A ``gr.Image`` whose value is the selected image URL.
    """
    # While we don't have the rights, always show this image. Previously the
    # page was fetched and its result thrown away, which only added latency and
    # a way for the handler to crash.
    if use_placeholder:
        return gr.Image(value="https://www.commissionoceanindien.org/wp-content/uploads/2018/07/plantnet.jpg")

    image_url = "https://www.salonlfc.com/wp-content/uploads/2018/01/image-not-found-scaled-1150x647.png"
    floraveg_url = f"https://floraveg.eu/taxon/overview/{species.capitalize()}"
    try:
        # Bounded wait (seconds) so a slow site cannot hang the UI handler.
        response = requests.get(floraveg_url, timeout=10)
    except requests.RequestException:
        # Treat network failures like a non-200 response: show the fallback.
        return gr.Image(value=image_url)
    if response.status_code == 200:
        soup = BeautifulSoup(response.text, 'html.parser')
        img_tag = soup.find('img', src=lambda x: x and x.startswith("https://files.ibot.cas.cz/cevs/images/taxa/large/"))
        if img_tag:
            image_url = img_tag['src']
    return gr.Image(value=image_url)

def gbif_normalization(text):
    """Normalize a comma-separated list of plant names with the GBIF match API.

    Each comma-separated entry is stripped and looked up via
    ``https://api.gbif.org/v1/species/match``. When the JSON answer contains a
    ``species`` field, that value replaces the entry. Otherwise, including on
    network errors or non-JSON answers, the entry is kept as typed.

    Args:
        text: Comma-separated species names, e.g.
            "Phragmites australis, Lemna minor".

    Returns:
        The resulting names joined with ", " and lower-cased.
    """
    match_url = "https://api.gbif.org/v1/species/match"
    species_gbif = []
    # One session reuses the HTTP connection across all look-ups.
    with requests.Session() as session:
        for species in (name.strip() for name in text.split(',')):
            normalized = species
            if species:  # an empty entry cannot match anything; skip the round-trip
                try:
                    # `params=` URL-encodes the name, so characters like '&' or '#'
                    # cannot corrupt the query string.
                    response = session.get(match_url, params={"name": species}, timeout=10)
                    data = response.json()
                except (requests.RequestException, ValueError):
                    # Best effort: keep the user's spelling if GBIF is unreachable
                    # or answers with something that is not JSON.
                    data = None
                if isinstance(data, dict) and 'species' in data:
                    normalized = data['species']
            species_gbif.append(normalized)
    return ", ".join(species_gbif).lower()

def classification(text, k):
    """Predict the most likely habitat(s) for a vegetation plot.

    Args:
        text: Comma-separated species names; normalized with gbif_normalization
            before classification.
        k: Number of habitat labels to report. It comes from a ``gr.Slider`` and
            may arrive as a float depending on the Gradio version, so it is
            coerced to an int (at least 1).

    Returns:
        ``(message, image)``. *message* lists the top habitat codes and the EUNIS
        name of the best one. *image* is the ``gr.Image`` returned by
        return_habitat_image for the best code.
    """
    k = max(1, int(k))  # slicing below requires an int
    text = gbif_normalization(text)
    result = classification_model(text)
    habitat_labels = [res['label'] for res in result[0][:k]]
    # Branch on the number of labels actually obtained, not on k, so the
    # sentence stays well-formed if the model returns fewer than k labels.
    if len(habitat_labels) == 1:
        text = f"This vegetation plot probably belongs to the habitat {habitat_labels[0]}."
    else:
        text = f"This vegetation plot probably belongs to the habitat {', '.join(habitat_labels[:-1])} or {habitat_labels[-1]}."
    best_label = habitat_labels[0]
    names = eunis_habitats.loc[eunis_habitats['EUNIS 2020 code'] == best_label, 'EUNIS-2021 habitat name']
    if names.empty:
        # Previously `.values[0]` raised IndexError for codes missing from the table.
        text += f"\nNo EUNIS habitat name was found for '{best_label}'."
    else:
        text += f"\nThe most likely habitat is '{names.iloc[0]}'."
    text += f"\nSee an image of this habitat (i.e., {best_label}) below."
    image_output = return_habitat_image(best_label)
    return text, image_output

def _best_unused_prediction(predictions, excluded):
    """Return the first (highest-ranked) prediction whose token is not in *excluded*, else None."""
    return next((p for p in predictions if p['token_str'] not in excluded), None)


def masking(text, k):
    """Suggest up to *k* species missing from a vegetation plot.

    The input is normalized with gbif_normalization and split on ", ". Each of
    the *k* rounds works as follows:
    - A ``[MASK]`` token is tried at every insertion position.
    - At each position, the highest-ranked fill-mask prediction that is not
      already in the plot is kept.
    - The prediction with the highest score across all positions is inserted
      before the next round.

    Args:
        text: Comma-separated species names.
        k: Number of species to suggest. It comes from a ``gr.Slider`` and may
            arrive as a float, so it is coerced to an int (at least 1).

    Returns:
        ``(message, image)``. *message* names the suggested species, their
        0-based positions in the completed list, and the completed plot.
        *image* is the ``gr.Image`` from return_species_image for the first
        suggestion, or ``None`` if no suggestion could be made.
    """
    k = max(1, int(k))  # range() requires an int
    text = gbif_normalization(text)
    text_split = text.split(', ')

    best_predictions = []
    best_sentence = None

    for _ in range(k):
        # Earlier suggestions are already inserted into text_split, so this set
        # excludes both the user's species and previous suggestions.
        excluded = set(text_split)
        best = None  # (score, position, species, sentence)

        for i in range(len(text_split) + 1):
            masked_text = ', '.join(text_split[:i] + ['[MASK]'] + text_split[i:])
            # Run the model once per position; the former loop re-ran it for
            # every skipped candidate and raised IndexError when all were used.
            prediction = _best_unused_prediction(mask_model(masked_text), excluded)
            if prediction is None:
                continue
            # Strict '>' keeps the earliest position on ties. Unlike starting
            # from max_score = 0, this never leaves the position unset.
            if best is None or prediction['score'] > best[0]:
                best = (prediction['score'], i, prediction['token_str'], prediction['sequence'])

        if best is None:
            break  # no unused candidate at any position
        _, position, species, best_sentence = best
        best_predictions.append(species)
        text_split.insert(position, species)

    if not best_predictions:
        return "No missing species could be suggested for this vegetation plot.", None

    best_positions = [text_split.index(prediction) for prediction in best_predictions]

    # Branch on how many species were actually found (may be fewer than k).
    if len(best_predictions) == 1:
        text = f"The most likely missing species is {best_predictions[0]} (position {best_positions[0]})."
    else:
        text = f"The most likely missing species are {', '.join(best_predictions[:-1])} and {best_predictions[-1]} (positions {', '.join(map(str, best_positions[:-1]))} and {best_positions[-1]})."
    text += f"\nThe completed vegetation plot is '{best_sentence}'."
    text += f"\nSee an image of the most likely species (i.e., {best_predictions[0]}) below."
    image = return_species_image(best_predictions[0])
    return text, image

# Build the two-tab Gradio interface: habitat classification and missing-species finding.
with gr.Blocks() as demo:

    gr.Markdown("""<h1 style="text-align: center;">Pl@ntBERT</h1>""")

    with gr.Tab("Vegetation plot classification"):
        gr.Markdown("""<h3 style="text-align: center;">Classification of vegetation plots!</h3>""")
        with gr.Row():
            with gr.Column():
                species_classification = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                k_classification = gr.Slider(1, 5, value=1, step=1, label="Top-k", info="Choose the number of habitats to display.")
            with gr.Column():
                text_classification = gr.Textbox(label="Prediction")
                image_classification = gr.Image()
        button_classification = gr.Button("Classify")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        # Keyword arguments make explicit that the example output is cached.
        gr.Examples(
            examples=[["phragmites australis, lemna minor, typha latifolia", 3]],
            inputs=[species_classification, k_classification],
            outputs=[text_classification, image_classification],
            fn=classification,
            cache_examples=True,
        )

    with gr.Tab("Missing species finding"):
        gr.Markdown("""<h3 style="text-align: center;">Finding the missing species!</h3>""")
        with gr.Row():
            with gr.Column():
                species_masking = gr.Textbox(lines=2, label="Species", placeholder="Enter a list of comma-separated binomial names here.")
                k_masking = gr.Slider(1, 5, value=1, step=1, label="Top-k", info="Choose the number of missing species to find.")
            with gr.Column():
                text_masking = gr.Textbox(label="Prediction")
                image_masking = gr.Image()
        button_masking = gr.Button("Find")
        gr.Markdown("""<h5 style="text-align: center;">An example of input</h5>""")
        gr.Examples(
            examples=[["calamagrostis arenaria, medicago marina, pancratium maritimum, thinopyrum junceum", 1]],
            inputs=[species_masking, k_masking],
            outputs=[text_masking, image_masking],
            fn=masking,
            cache_examples=True,
        )

    # Wire each button to its handler.
    button_classification.click(classification, inputs=[species_classification, k_classification], outputs=[text_classification, image_classification])
    button_masking.click(masking, inputs=[species_masking, k_masking], outputs=[text_masking, image_masking])

# Launch only when run as a script (`python app.py`). Importing the module (tests,
# reuse of `demo`) then no longer blocks on the server loop.
if __name__ == "__main__":
    demo.launch()