Merge pull request #1 from CCCBora/semantic-scholar
- .idea/.gitignore +10 -0
- .idea/auto-draft.iml +14 -0
- .idea/inspectionProfiles/Project_Default.xml +95 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +4 -0
- .idea/modules.xml +8 -0
- .idea/vcs.xml +6 -0
- __pycache__/auto_backgrounds.cpython-310.pyc +0 -0
- __pycache__/auto_draft.cpython-310.pyc +0 -0
- __pycache__/section_generator.cpython-310.pyc +0 -0
- app.py +35 -24
- auto_backgrounds.py +13 -5
- section_generator.py +2 -2
- utils/__pycache__/__init__.cpython-310.pyc +0 -0
- utils/__pycache__/figures.cpython-310.pyc +0 -0
- utils/__pycache__/file_operations.cpython-310.pyc +0 -0
- utils/__pycache__/gpt_interaction.cpython-310.pyc +0 -0
- utils/__pycache__/prompts.cpython-310.pyc +0 -0
- utils/__pycache__/references.cpython-310.pyc +0 -0
- utils/__pycache__/storage.cpython-310.pyc +0 -0
- utils/__pycache__/tex_processing.cpython-310.pyc +0 -0
- utils/prompts.py +5 -0
- utils/references.py +131 -16
- utils/tex_processing.py +3 -1
.idea/.gitignore
ADDED
@@ -0,0 +1,10 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+**/__pycache__
+**/.idea
.idea/auto-draft.iml
ADDED
@@ -0,0 +1,14 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$">
+      <excludeFolder url="file://$MODULE_DIR$/venv" />
+    </content>
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+  <component name="PyDocumentationSettings">
+    <option name="format" value="PLAIN" />
+    <option name="myDocStringFormat" value="Plain" />
+  </component>
+</module>
.idea/inspectionProfiles/Project_Default.xml
ADDED
@@ -0,0 +1,95 @@
+<component name="InspectionProjectProfileManager">
+  <profile version="1.0">
+    <option name="myName" value="Project Default" />
+    <inspection_tool class="PyChainedComparisonsInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
+      <option name="ignoreConstantInTheMiddle" value="true" />
+    </inspection_tool>
+    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
+      <option name="ignoredPackages">
+        <value>
+          <list size="69">
+            <item index="0" class="java.lang.String" itemvalue="pprint" />
+            <item index="1" class="java.lang.String" itemvalue="tnt" />
+            <item index="2" class="java.lang.String" itemvalue="pyglet" />
+            <item index="3" class="java.lang.String" itemvalue="pyzmq" />
+            <item index="4" class="java.lang.String" itemvalue="gym" />
+            <item index="5" class="java.lang.String" itemvalue="torch" />
+            <item index="6" class="java.lang.String" itemvalue="numpy" />
+            <item index="7" class="java.lang.String" itemvalue="absl-py" />
+            <item index="8" class="java.lang.String" itemvalue="numba" />
+            <item index="9" class="java.lang.String" itemvalue="protobuf" />
+            <item index="10" class="java.lang.String" itemvalue="torch-scatter" />
+            <item index="11" class="java.lang.String" itemvalue="joblib" />
+            <item index="12" class="java.lang.String" itemvalue="threadpoolctl" />
+            <item index="13" class="java.lang.String" itemvalue="scikit-learn" />
+            <item index="14" class="java.lang.String" itemvalue="PyYAML" />
+            <item index="15" class="java.lang.String" itemvalue="python-dateutil" />
+            <item index="16" class="java.lang.String" itemvalue="cycler" />
+            <item index="17" class="java.lang.String" itemvalue="MarkupSafe" />
+            <item index="18" class="java.lang.String" itemvalue="mpi4py" />
+            <item index="19" class="java.lang.String" itemvalue="torchvision" />
+            <item index="20" class="java.lang.String" itemvalue="line-profiler" />
+            <item index="21" class="java.lang.String" itemvalue="pyasn1-modules" />
+            <item index="22" class="java.lang.String" itemvalue="certifi" />
+            <item index="23" class="java.lang.String" itemvalue="oauthlib" />
+            <item index="24" class="java.lang.String" itemvalue="pyparsing" />
+            <item index="25" class="java.lang.String" itemvalue="Markdown" />
+            <item index="26" class="java.lang.String" itemvalue="Werkzeug" />
+            <item index="27" class="java.lang.String" itemvalue="h5py" />
+            <item index="28" class="java.lang.String" itemvalue="rdflib" />
+            <item index="29" class="java.lang.String" itemvalue="torch-cluster" />
+            <item index="30" class="java.lang.String" itemvalue="kiwisolver" />
+            <item index="31" class="java.lang.String" itemvalue="pytorch-lightning" />
+            <item index="32" class="java.lang.String" itemvalue="tensorboard" />
+            <item index="33" class="java.lang.String" itemvalue="imageio" />
+            <item index="34" class="java.lang.String" itemvalue="matplotlib" />
+            <item index="35" class="java.lang.String" itemvalue="test-tube" />
+            <item index="36" class="java.lang.String" itemvalue="googledrivedownloader" />
+            <item index="37" class="java.lang.String" itemvalue="idna" />
+            <item index="38" class="java.lang.String" itemvalue="rsa" />
+            <item index="39" class="java.lang.String" itemvalue="networkx" />
+            <item index="40" class="java.lang.String" itemvalue="isodate" />
+            <item index="41" class="java.lang.String" itemvalue="torch-sparse" />
+            <item index="42" class="java.lang.String" itemvalue="llvmlite" />
+            <item index="43" class="java.lang.String" itemvalue="pyasn1" />
+            <item index="44" class="java.lang.String" itemvalue="requests" />
+            <item index="45" class="java.lang.String" itemvalue="importlib-metadata" />
+            <item index="46" class="java.lang.String" itemvalue="Jinja2" />
+            <item index="47" class="java.lang.String" itemvalue="requests-oauthlib" />
+            <item index="48" class="java.lang.String" itemvalue="tensorboard-plugin-wit" />
+            <item index="49" class="java.lang.String" itemvalue="zipp" />
+            <item index="50" class="java.lang.String" itemvalue="urllib3" />
+            <item index="51" class="java.lang.String" itemvalue="torch-geometric" />
+            <item index="52" class="java.lang.String" itemvalue="scipy" />
+            <item index="53" class="java.lang.String" itemvalue="six" />
+            <item index="54" class="java.lang.String" itemvalue="google-auth-oauthlib" />
+            <item index="55" class="java.lang.String" itemvalue="chardet" />
+            <item index="56" class="java.lang.String" itemvalue="pandas" />
+            <item index="57" class="java.lang.String" itemvalue="tqdm" />
+            <item index="58" class="java.lang.String" itemvalue="torch-spline-conv" />
+            <item index="59" class="java.lang.String" itemvalue="ase" />
+            <item index="60" class="java.lang.String" itemvalue="future" />
+            <item index="61" class="java.lang.String" itemvalue="cachetools" />
+            <item index="62" class="java.lang.String" itemvalue="grpcio" />
+            <item index="63" class="java.lang.String" itemvalue="pytz" />
+            <item index="64" class="java.lang.String" itemvalue="google-auth" />
+            <item index="65" class="java.lang.String" itemvalue="Pillow" />
+            <item index="66" class="java.lang.String" itemvalue="decorator" />
+            <item index="67" class="java.lang.String" itemvalue="typing-extensions" />
+            <item index="68" class="java.lang.String" itemvalue="ale-py" />
+          </list>
+        </value>
+      </option>
+    </inspection_tool>
+    <inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
+      <option name="ignoredErrors">
+        <list>
+          <option value="N812" />
+          <option value="N802" />
+          <option value="N803" />
+          <option value="N806" />
+        </list>
+      </option>
+    </inspection_tool>
+  </profile>
+</component>
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
.idea/misc.xml
ADDED
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (auto-draft)" project-jdk-type="Python SDK" />
+</project>
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/auto-draft.iml" filepath="$PROJECT_DIR$/.idea/auto-draft.iml" />
+    </modules>
+  </component>
+</project>
.idea/vcs.xml
ADDED
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>
__pycache__/auto_backgrounds.cpython-310.pyc
ADDED
Binary file (4.06 kB)
__pycache__/auto_draft.cpython-310.pyc
ADDED
Binary file (4.56 kB)
__pycache__/section_generator.cpython-310.pyc
ADDED
Binary file (2.42 kB)
app.py
CHANGED
@@ -4,15 +4,20 @@ import openai
 from auto_backgrounds import generate_backgrounds, fake_generator, generate_draft
 from utils.file_operations import hash_name
 
+# note: white-screen bug in the App: allow third-party cookies
 # todo:
-#
-#
-# 5. Use some simple method for simple tasks (including: writing abstract, conclusion, generate keywords, generate figures...)
+# 5. Use some simple method for simple tasks
+#    (including: writing abstract, conclusion, generate keywords, generate figures...)
 # 5.1 Use GPT 3.5 for abstract, conclusion, ... (or may not)
 # 5.2 Use local LLM to generate keywords, figures, ...
 # 5.3 Use embedding to find most related papers (find a paper dataset)
-# 5.4 Use Semantic Scholar API instead of Arxiv API.
 # 6. get logs when the procedure is not completed.
+# 7. own reference library; more prompts
+# 11. distinguish citep and citet
+# future:
+# 8. Change prompts to langchain
+# 4. add auto_polishing function
+# 12. Change link to more appealing color  # after the website is built;
 
 openai_key = os.getenv("OPENAI_API_KEY")
 access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
@@ -35,14 +40,13 @@ else:
     IS_OPENAI_API_KEY_AVAILABLE = False
 
 
-
 def clear_inputs(text1, text2):
     return "", ""
 
 
-def wrapped_generator(title, description, openai_key = None,
-                      template="ICLR2022",
-                      cache_mode=IS_CACHE_AVAILABLE):
+def wrapped_generator(paper_title, paper_description, openai_api_key=None,
+                      template="ICLR2022",
+                      cache_mode=IS_CACHE_AVAILABLE, generator=None):
     # if `cache_mode` is True, then follow the following steps:
     #     check if "title"+"description" have been generated before
     #     if so, download from the cloud storage, return it
@@ -52,15 +56,16 @@ def wrapped_generator(title, description, openai_key = None,
     # generator = generate_backgrounds
     generator = generate_draft
     # generator = fake_generator
-    if openai_key is not None:
-        openai.api_key = openai_key
+    if openai_api_key is not None:
+        openai.api_key = openai_api_key
     openai.Model.list()
 
     if cache_mode:
         from utils.storage import list_all_files, download_file, upload_file
         # check if "title"+"description" have been generated before
 
-        input_dict = {"title": title, "description": description}
+        input_dict = {"title": paper_title, "description": paper_description,
+                      "generator": "generate_draft"}  # todo: modify here also
         file_name = hash_name(input_dict) + ".zip"
         file_list = list_all_files()
         # print(f"{file_name} will be generated. Check the file list {file_list}")
@@ -70,21 +75,23 @@ def wrapped_generator(title, description, openai_key = None,
             return file_name
         else:
             # generate the result.
-            # output = fake_generate_backgrounds(title, description, openai_key)
-            output = generator(title, description, template, "gpt-4")
+            # output = fake_generate_backgrounds(title, description, openai_key)
+            # todo: use `generator` to control which function to use.
+            output = generator(paper_title, paper_description, template, "gpt-4")
             upload_file(output)
             return output
     else:
         # output = fake_generate_backgrounds(title, description, openai_key)
-        output = generator(title, description, template, "gpt-4")
+        output = generator(paper_title, paper_description, template, "gpt-4")
         return output
 
 
-theme = gr.themes.
+theme = gr.themes.Default(font=gr.themes.GoogleFont("Questrial"))
+# .set(
+#     background_fill_primary='#E5E4E2',
+#     background_fill_secondary = '#F6F6F6',
+#     button_primary_background_fill="#281A39"
+# )
 
 with gr.Blocks(theme=theme) as demo:
     gr.Markdown('''
@@ -102,16 +109,20 @@ with gr.Blocks(theme=theme) as demo:
     ''')
     with gr.Row():
        with gr.Column(scale=2):
-            key =
+            key = gr.Textbox(value=openai_key, lines=1, max_lines=1, label="OpenAI Key",
+                             visible=not IS_OPENAI_API_KEY_AVAILABLE)
+            # generator = gr.Dropdown(choices=["学术论文", "文献总结"], value="文献总结",
+            #                         label="Selection", info="目前支持生成'学术论文'和'文献总结'.", interactive=True)
+            title = gr.Textbox(value="Playing Atari with Deep Reinforcement Learning", lines=1, max_lines=1,
+                               label="Title", info="论文标题")
             description = gr.Textbox(lines=5, label="Description (Optional)", visible=False)
 
             with gr.Row():
                 clear_button = gr.Button("Clear")
-                submit_button = gr.Button("Submit")
+                submit_button = gr.Button("Submit", variant="primary")
         with gr.Column(scale=1):
-            style_mapping = {True: "color:white;background-color:green",
+            style_mapping = {True: "color:white;background-color:green",
+                             False: "color:white;background-color:red"}  # todo: to match website's style
             availability_mapping = {True: "AVAILABLE", False: "NOT AVAILABLE"}
             gr.Markdown(f'''## Huggingface Space Status
                         当`OpenAI API`显示AVAILABLE的时候这个Space可以直接使用.
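Note: the cache branch above works by content-addressing: the title, description, and generator name are hashed into a file name, and the Space reuses the stored ZIP when that name already exists in cloud storage. A minimal sketch of that lookup, reusing the helpers imported in app.py (the function name `cached_or_none` is illustrative, not part of this commit):

    from utils.file_operations import hash_name
    from utils.storage import list_all_files, download_file

    def cached_or_none(paper_title, paper_description):
        # Hash the inputs into a deterministic file name, mirroring wrapped_generator.
        input_dict = {"title": paper_title, "description": paper_description,
                      "generator": "generate_draft"}
        file_name = hash_name(input_dict) + ".zip"
        if file_name in list_all_files():
            download_file(file_name)   # cache hit: reuse the stored draft
            return file_name
        return None                    # cache miss: caller generates and uploads
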
auto_backgrounds.py
CHANGED
@@ -30,7 +30,8 @@ def log_usage(usage, generating_target, print_out=True):
     print(message)
     logging.info(message)
 
-def _generation_setup(title, description="", template="ICLR2022", model="gpt-4"):
+def _generation_setup(title, description="", template="ICLR2022", model="gpt-4",
+                      search_engine="ss", tldr=False, max_kw_refs=10):
     '''
     todo: use `model` to control which model to use; may use another method to generate keywords or collect references
     '''
@@ -44,12 +45,12 @@ def _generation_setup(title, description="", template="ICLR2022", model="gpt-4")
     # Generate keywords and references
     print("Initialize the paper information ...")
     input_dict = {"title": title, "description": description}
-    keywords, usage = keywords_generation(input_dict, model="gpt-3.5-turbo")
+    keywords, usage = keywords_generation(input_dict, model="gpt-3.5-turbo", max_kw_refs=max_kw_refs)
     print(f"keywords: {keywords}")
     log_usage(usage, "keywords")
 
     ref = References(load_papers="")
-    ref.collect_papers(keywords, method="arxiv")
+    ref.collect_papers(keywords, method=search_engine, tldr=tldr)
     all_paper_ids = ref.to_bibtex(bibtex_path)  # todo: this will used to check if all citations are in this list
 
     print(f"The paper information has been initialized. References are saved to {bibtex_path}.")
@@ -90,8 +91,8 @@ def fake_generator(title, description="", template="ICLR2022", model="gpt-4"):
     return make_archive("sample-output.pdf", filename)
 
 
-def generate_draft(title, description="", template="ICLR2022", model="gpt-4"):
-    paper, destination_folder, _ = _generation_setup(title, description, template, model)
+def generate_draft(title, description="", template="ICLR2022", model="gpt-4", search_engine="ss", tldr=True, max_kw_refs=14):
+    paper, destination_folder, _ = _generation_setup(title, description, template, model, search_engine, tldr, max_kw_refs)
 
     # todo: `list_of_methods` failed to be generated; find a solution ...
     # print("Generating figures ...")
@@ -125,3 +126,10 @@ def generate_draft(title, description="", template="ICLR2022", model="gpt-4"):
     input_dict = {"title": title, "description": description, "generator": "generate_draft"}
     filename = hash_name(input_dict) + ".zip"
     return make_archive(destination_folder, filename)
+
+
+if __name__ == "__main__":
+    title = "Using interpretable boosting algorithms for modeling environmental and agricultural data"
+    description = ""
+    output = generate_draft(title, description, search_engine="ss", tldr=True, max_kw_refs=10)
+    print(output)
section_generator.py
CHANGED
@@ -76,11 +76,11 @@ def section_generation(paper, section, save_to_path, model):
     print(f"{section} has been generated. Saved to {tex_file}.")
     return usage
 
-def keywords_generation(input_dict, model):
+def keywords_generation(input_dict, model, max_kw_refs = 10):
     title = input_dict.get("title")
     description = input_dict.get("description", "")
     if title is not None:
-        prompts = generate_keywords_prompts(title, description)
+        prompts = generate_keywords_prompts(title, description, max_kw_refs)
         gpt_response, usage = get_responses(prompts, model)
         keywords = extract_keywords(gpt_response)
         return keywords, usage
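Note: `keywords_generation` is expected to return a keyword-to-count mapping, which `References.collect_papers` consumes (see its docstring in utils/references.py). A hypothetical round trip, with illustrative keyword strings and counts:

    from utils.references import References

    # Shape of the dict handed from keywords_generation to collect_papers:
    # search phrase -> number of papers to fetch for it (values made up here).
    keywords = {"machine learning": 5, "language model": 2}

    ref = References(load_papers="")
    ref.collect_papers(keywords, method="ss", tldr=True)  # "ss" = Semantic Scholar
    print(len(ref.papers))
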
utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (137 Bytes)
utils/__pycache__/figures.cpython-310.pyc
ADDED
Binary file (1.89 kB)
utils/__pycache__/file_operations.cpython-310.pyc
ADDED
Binary file (1.41 kB)
utils/__pycache__/gpt_interaction.cpython-310.pyc
ADDED
Binary file (2.79 kB)
utils/__pycache__/prompts.cpython-310.pyc
ADDED
Binary file (6.66 kB)
utils/__pycache__/references.cpython-310.pyc
ADDED
Binary file (6.77 kB)
utils/__pycache__/storage.cpython-310.pyc
ADDED
Binary file (1.71 kB)
utils/__pycache__/tex_processing.cpython-310.pyc
ADDED
Binary file (609 Bytes)
utils/prompts.py
CHANGED
@@ -10,6 +10,11 @@ INSTRUCTIONS = {"introduction": "Please include five paragraph: Establishing the
                 "conclusion": "Please read the paper I have written and write the conclusion section.",
                 "abstract": "Please read the paper I have written and write the abstract."}
 
+INSTRUCTIONS["related works"] = r"Please discuss three to five main related fields to this paper. For each field, select " \
+                                r"five to ten key publications from references. For each reference, analyze its strengths and weaknesses in one or two sentences. " \
+                                r"Do not use \section{...} or \subsection{...}; use \paragraph{...} to list related fields. "
+
+
 BG_INSTRUCTIONS = {"introduction": "Please include four paragraph: Establishing the motivation for this survey. Explaining its importance and relevance to the AI community. Clearly state the coverage of this survey and the specific research questions or objectives. Briefly mention key related work for context. ",
                    "related works": r"Please discuss key publications, methods, and techniques in related research area. Analyze the strengths and weaknesses of existing methods, and present the related works in a logical manner, often chronologically. Consider using a taxonomy or categorization to structure the discussion. Do not use \section{...} or \subsection{...}; use \paragraph{...} instead. ",
                    "backgrounds": r"Please clearly state the central problem in this field. Explain the foundational theories, concepts, and principles that underpin your research using as many as mathematical formulas or equations (written in LaTeX). Introduce any necessary mathematical notations, equations, or algorithms that are central to this field (written them in LaTeX). Do not include \section{...} but you can have \subsection{...}. ",}
utils/references.py
CHANGED
@@ -8,10 +8,115 @@
 import requests
 import re
 
-
-
-
-def _collect_papers_arxiv(keyword, counts=3):
+
+#########################################################
+# Some basic tools
+#########################################################
+def remove_newlines(serie):
+    serie = serie.replace('\n', ' ')
+    serie = serie.replace('\\n', ' ')
+    serie = serie.replace('  ', ' ')
+    serie = serie.replace('  ', ' ')
+    return serie
+
+
+#########################################################
+# Semantic Scholar (SS) API
+#########################################################
+def ss_search(keywords, limit=20, fields=None):
+    # space between the query to be removed and replaced with +
+    if fields is None:
+        fields = ["title", "abstract", "venue", "year", "authors", "tldr", "embedding", "externalIds"]
+    keywords = keywords.lower()
+    keywords = keywords.replace(" ", "+")
+    url = f'https://api.semanticscholar.org/graph/v1/paper/search?query={keywords}&limit={limit}&fields={",".join(fields)}'
+    # headers = {"Accept": "*/*", "x-api-key": constants.S2_KEY}
+    headers = {"Accept": "*/*"}
+
+    response = requests.get(url, headers=headers, timeout=30)
+    return response.json()
+
+
+def _collect_papers_ss(keyword, counts=3, tldr=False):
+    def externalIds2link(externalIds):
+        # Sample externalIds:
+        #   "{'MAG': '2932819148', 'DBLP': 'conf/icml/HaarnojaZAL18', 'ArXiv': '1801.01290', 'CorpusId': 28202810}"
+        if externalIds:
+            # Supports ArXiv, MAG, ACL, PubMed, Medline, PubMedCentral, DBLP, DOI
+            # priority: DBLP > arXiv > (todo: MAG > CorpusId > DOI > ACL > PubMed > Medline > PubMedCentral)
+            # DBLP
+            dblp_id = externalIds.get('DBLP')
+            if dblp_id is not None:
+                dblp_link = f"dblp.org/rec/{dblp_id}"
+                return dblp_link
+            # arXiv
+            arxiv_id = externalIds.get('ArXiv')
+            if arxiv_id is not None:
+                arxiv_link = f"arxiv.org/abs/{arxiv_id}"
+                return arxiv_link
+            return ""
+        else:
+            # if this is an empty dictionary, return an empty string
+            return ""
+
+    def extract_paper_id(last_name, year_str, title):
+        pattern = r'^\w+'
+        words = re.findall(pattern, title)
+        # return last_name + year_str + title.split(' ', 1)[0]
+        return last_name + year_str + words[0]
+
+    def extract_author_info(raw_authors):
+        authors = [author['name'] for author in raw_authors]
+
+        authors_str = " and ".join(authors)
+        last_name = authors[0].split()[-1]
+        return authors_str, last_name
+
+    def parse_search_results(search_results_ss):
+        # turn the search result to a list of paper dictionary.
+        papers = []
+        for raw_paper in search_results_ss:
+            if raw_paper["abstract"] is None:
+                continue
+
+            authors_str, last_name = extract_author_info(raw_paper['authors'])
+            year_str = str(raw_paper['year'])
+            title = raw_paper['title']
+            # some journal may contain &; replace it. e.g. journal={IEEE Power & Energy Society General Meeting}
+            journal = raw_paper['venue'].replace("&", "\\&")
+            if not journal:
+                journal = "arXiv preprint"
+            paper_id = extract_paper_id(last_name, year_str, title).lower()
+            link = externalIds2link(raw_paper['externalIds'])
+            if tldr and raw_paper['tldr'] is not None:
+                abstract = raw_paper['tldr']['text']
+            else:
+                abstract = remove_newlines(raw_paper['abstract'])
+            result = {
+                "paper_id": paper_id,
+                "title": title,
+                "abstract": abstract,  # todo: compare results with tldr
+                "link": link,
+                "authors": authors_str,
+                "year": year_str,
+                "journal": journal
+            }
+            papers.append(result)
+        return papers
+
+    raw_results = ss_search(keyword, limit=counts)
+    if raw_results is not None:
+        search_results = raw_results['data']
+    else:
+        search_results = []
+    results = parse_search_results(search_results)
+    return results
+
+
+#########################################################
+# ArXiv API
+#########################################################
+def _collect_papers_arxiv(keyword, counts=3, tldr=False):
     # Build the arXiv API query URL with the given keyword and other parameters
     def build_query_url(keyword, results_limit=3, sort_by="relevance", sort_order="descending"):
         base_url = "http://export.arxiv.org/api/query?"
@@ -37,6 +142,7 @@ def _collect_papers_arxiv(keyword, counts=3):
         title = entry.find(f"{namespace}title").text
         link = entry.find(f"{namespace}id").text
         summary = entry.find(f"{namespace}summary").text
+        summary = remove_newlines(summary)
 
         # Extract the authors
         authors = entry.findall(f"{namespace}author")
@@ -76,9 +182,14 @@
     results = parse_results(content)
     return results
 
+
+#########################################################
+# References Class
+#########################################################
+
 # Each `paper` is a dictionary containing (1) paper_id (2) title (3) authors (4) year (5) link (6) abstract (7) journal
 class References:
-    def __init__(self, load_papers):
+    def __init__(self, load_papers=""):
         if load_papers:
             # todo: read a json file from the given path
             # this could be used to support pre-defined references
@@ -86,7 +197,7 @@ class References:
         else:
             self.papers = []
 
-    def collect_papers(self, keywords_dict, method="arxiv"):
+    def collect_papers(self, keywords_dict, method="arxiv", tldr=False):
        """
        keywords_dict:
            {"machine learning": 5, "language model": 2};
@@ -94,11 +205,13 @@ class References:
        """
        match method:
            case "arxiv":
-                process =_collect_papers_arxiv
+                process = _collect_papers_arxiv
+            case "ss":
+                process = _collect_papers_ss
            case _:
                raise NotImplementedError("Other sources have not been not supported yet.")
        for key, counts in keywords_dict.items():
-            self.papers = self.papers + process(key, counts)
+            self.papers = self.papers + process(key, counts, tldr)
 
        seen = set()
        papers = []
@@ -146,15 +259,17 @@ class References:
             prompts[paper["paper_id"]] = paper["abstract"]
         return prompts
 
+
 if __name__ == "__main__":
     refs = References()
     keywords_dict = {
-    }
-    refs.collect_papers(keywords_dict)
+        "Deep Q-Networks": 15,
+        "Policy Gradient Methods": 24,
+        "Actor-Critic Algorithms": 4,
+        "Model-Based Reinforcement Learning": 13,
+        "Exploration-Exploitation Trade-off": 7
+    }
+    refs.collect_papers(keywords_dict, method="ss", tldr=True)
     for p in refs.papers:
-        print(p["paper_id"])
+        print(p["paper_id"])
+    print(len(refs.papers))
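Note: the Graph API endpoint queried by `ss_search` returns JSON whose `data` field holds one record per paper with the requested fields; `_collect_papers_ss` reads exactly that field. A quick smoke test (assumes network access; unauthenticated requests are rate-limited by Semantic Scholar):

    from utils.references import ss_search

    resp = ss_search("deep q-networks", limit=3)
    for raw_paper in resp.get("data", []):
        # Each record carries the fields requested in ss_search's default list.
        print(raw_paper["title"], raw_paper["year"], raw_paper["externalIds"])
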
utils/tex_processing.py
CHANGED
@@ -24,4 +24,6 @@ def replace_title(save_to_path, title):
 # check if citations are in bibtex.
 
 
-# replace citations
+# replace citations
+
+# sometimes the output may include thebibliography and bibitem . remove all of it.