Fer Aguirre commited on
Commit
998cded
1 Parent(s): b0f265a

Initial commit

Browse files
Files changed (3) hide show
  1. app.py +86 -0
  2. foia_sample.csv +0 -0
  3. requirements.txt +129 -0
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from datasets import Dataset
4
+ from sentence_transformers import SentenceTransformer
5
+ from sentence_transformers.util import semantic_search
6
+ import torch
7
+
8
+ model = SentenceTransformer("sentence-transformers/gtr-t5-large")
9
+
10
+
11
+ # Read files
12
+ url = "https://gist.githubusercontent.com/fer-aguirre/b6bdcf59ecae41f84765f72114de9fd1/raw/b4e029fe236c1f38275621686429b2c7aaa3d18b/embeddings.csv"
13
+
14
+ df_emb = pd.read_csv(url, index_col=0)
15
+
16
+ df = pd.read_csv('./foia_sample.csv')
17
+
18
+ dataset = Dataset.from_pandas(df_emb)
19
+
20
+ dataset_embeddings = torch.from_numpy(dataset.to_pandas().to_numpy()).to(torch.float)
21
+
22
+ st.markdown("**Inserta una solicitud de información para generar recomendaciones de dependencias**")
23
+
24
+ if request := st.text_area("", value=""):
25
+
26
+ output = model.encode(request)
27
+
28
+ query_embeddings = torch.FloatTensor(output)
29
+
30
+ hits = semantic_search(query_embeddings, dataset_embeddings, top_k=3)
31
+
32
+ id1 = hits[0][0]['corpus_id']
33
+ id2 = hits[0][1]['corpus_id']
34
+ id3 = hits[0][2]['corpus_id']
35
+
36
+ rec1 = df.iloc[id1].str.split(pat="/")[0]
37
+ rec2 = df.iloc[id2].str.split(pat="/")[0]
38
+ rec3 = df.iloc[id3].str.split(pat="/")[0]
39
+
40
+ list_rec = [rec1, rec2, rec3]
41
+ unique_list = []
42
+ for string in list_rec:
43
+ if string not in unique_list:
44
+ unique_list.append(string)
45
+ st.markdown(f'Recomendaciones:')
46
+ for rec in unique_list:
47
+ st.markdown(f':green[{rec[0]}]')
48
+
49
+ st.markdown("""---""")
50
+
51
+ if st.button('Genera un ejemplo random'):
52
+
53
+ test_example = df['combined'].sample(n=1)
54
+ index = test_example.index
55
+ idx = index[0]
56
+
57
+ original = df.iloc[idx].str.split(pat="/")[0]
58
+
59
+ request = test_example.to_string(index=False)
60
+
61
+ st.text(f'{idx}, {request}')
62
+
63
+ output = model.encode(request)
64
+
65
+ query_embeddings = torch.FloatTensor(output)
66
+
67
+ hits = semantic_search(query_embeddings, dataset_embeddings, top_k=3)
68
+
69
+ id1 = hits[0][0]['corpus_id']
70
+ id2 = hits[0][1]['corpus_id']
71
+ id3 = hits[0][2]['corpus_id']
72
+
73
+ rec1 = df.iloc[id1].str.split(pat="/")[0]
74
+ rec2 = df.iloc[id2].str.split(pat="/")[0]
75
+ rec3 = df.iloc[id3].str.split(pat="/")[0]
76
+
77
+ list_rec = [rec1, rec2, rec3]
78
+ unique_list = []
79
+ for string in list_rec:
80
+ if string not in unique_list:
81
+ unique_list.append(string)
82
+ st.markdown(f'Recomendaciones:')
83
+ for rec in unique_list:
84
+ st.markdown(f':green[{rec[0]}]')
85
+ st.markdown(f'Dependencia original:')
86
+ st.markdown(f':red[{original[0]}]')
foia_sample.csv ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -e git+ssh://git@github.com/fer-aguirre/ai4foia.git@3469e89044d7f0ccfb440fb5762fd7cbd893fa82#egg=AI4FOIA
2
+ aiohttp==3.8.4
3
+ aiosignal==1.3.1
4
+ altair==4.2.2
5
+ asttokens==2.2.1
6
+ async-timeout==4.0.2
7
+ attrs==22.2.0
8
+ backcall==0.2.0
9
+ backports.zoneinfo==0.2.1
10
+ blinker==1.5
11
+ cachetools==5.3.0
12
+ certifi==2022.12.7
13
+ charset-normalizer==3.1.0
14
+ click==8.1.3
15
+ cmake==3.26.1
16
+ comm==0.1.3
17
+ datasets==2.11.0
18
+ debugpy==1.6.6
19
+ decorator==5.1.1
20
+ dill==0.3.6
21
+ distlib==0.3.6
22
+ entrypoints==0.4
23
+ executing==1.2.0
24
+ fastjsonschema==2.16.3
25
+ filelock==3.10.7
26
+ frozenlist==1.3.3
27
+ fsspec==2023.3.0
28
+ gitdb==4.0.10
29
+ GitPython==3.1.31
30
+ huggingface-hub==0.13.3
31
+ idna==3.4
32
+ importlib-metadata==6.1.0
33
+ importlib-resources==5.12.0
34
+ ipykernel==6.22.0
35
+ ipython==8.12.0
36
+ jedi==0.18.2
37
+ Jinja2==3.1.2
38
+ joblib==1.2.0
39
+ jsonschema==4.17.3
40
+ jupyter_client==8.1.0
41
+ jupyter_core==5.3.0
42
+ lit==16.0.0
43
+ markdown-it-py==2.2.0
44
+ MarkupSafe==2.1.2
45
+ matplotlib-inline==0.1.6
46
+ mdurl==0.1.2
47
+ mpmath==1.3.0
48
+ multidict==6.0.4
49
+ multiprocess==0.70.14
50
+ nbformat==5.8.0
51
+ nest-asyncio==1.5.6
52
+ networkx==3.0
53
+ nltk==3.8.1
54
+ numpy==1.24.2
55
+ nvidia-cublas-cu11==11.10.3.66
56
+ nvidia-cuda-cupti-cu11==11.7.101
57
+ nvidia-cuda-nvrtc-cu11==11.7.99
58
+ nvidia-cuda-runtime-cu11==11.7.99
59
+ nvidia-cudnn-cu11==8.5.0.96
60
+ nvidia-cufft-cu11==10.9.0.58
61
+ nvidia-curand-cu11==10.2.10.91
62
+ nvidia-cusolver-cu11==11.4.0.1
63
+ nvidia-cusparse-cu11==11.7.4.91
64
+ nvidia-nccl-cu11==2.14.3
65
+ nvidia-nvtx-cu11==11.7.91
66
+ packaging==23.0
67
+ pandas==1.5.3
68
+ parso==0.8.3
69
+ pathlib==1.0.1
70
+ pbr==5.11.1
71
+ pexpect==4.8.0
72
+ pickleshare==0.7.5
73
+ Pillow==9.5.0
74
+ pipenv==2023.3.20
75
+ pkgutil_resolve_name==1.3.10
76
+ platformdirs==3.2.0
77
+ prompt-toolkit==3.0.38
78
+ protobuf==3.20.3
79
+ psutil==5.9.4
80
+ ptyprocess==0.7.0
81
+ pure-eval==0.2.2
82
+ pyarrow==11.0.0
83
+ pydeck==0.8.0
84
+ Pygments==2.14.0
85
+ Pympler==1.0.1
86
+ pyprojroot==0.3.0
87
+ pyrsistent==0.19.3
88
+ python-dateutil==2.8.2
89
+ pytz==2023.3
90
+ pytz-deprecation-shim==0.1.0.post0
91
+ PyYAML==6.0
92
+ pyzmq==25.0.2
93
+ regex==2023.3.23
94
+ requests==2.28.2
95
+ responses==0.18.0
96
+ rich==13.3.3
97
+ scikit-learn==1.2.2
98
+ scipy==1.10.1
99
+ semver==2.13.0
100
+ sentence-transformers==2.2.2
101
+ sentencepiece==0.1.97
102
+ six==1.16.0
103
+ smmap==5.0.0
104
+ stack-data==0.6.2
105
+ streamlit==1.20.0
106
+ sympy==1.11.1
107
+ threadpoolctl==3.1.0
108
+ tokenizers==0.13.2
109
+ toml==0.10.2
110
+ toolz==0.12.0
111
+ torch==2.0.0
112
+ torchvision==0.15.1
113
+ tornado==6.2
114
+ tqdm==4.65.0
115
+ traitlets==5.9.0
116
+ transformers==4.27.4
117
+ triton==2.0.0
118
+ typing_extensions==4.5.0
119
+ tzdata==2023.3
120
+ tzlocal==4.3
121
+ urllib3==1.26.15
122
+ validators==0.20.0
123
+ virtualenv==20.21.0
124
+ virtualenv-clone==0.5.7
125
+ watchdog==3.0.0
126
+ wcwidth==0.2.6
127
+ xxhash==3.2.0
128
+ yarl==1.8.2
129
+ zipp==3.15.0