# paper-central / pr_paper_central_tab.py
# Copied from the Hugging Face Hub file view: author IAMJB, commit 4e0c371 ("update api"), 11.2 kB.
import gradio as gr
from typing import Optional
import pandas as pd
from huggingface_hub import HfApi, hf_hub_download, CommitOperationAdd
import json
import os
import requests
# Open a pull request on the IAMJB/paper-central-pr dataset that appends one entry to data.json.
def create_pr_in_hf_dataset(new_entry, oauth_token: gr.OAuthToken):
    """Append ``new_entry`` to the PR dataset's ``data.json`` and open a pull request.

    On first use the dataset repository is created and seeded with an empty
    ``data.json`` list. The current ``data.json`` is then downloaded,
    ``new_entry`` is appended, and the updated file is committed as a pull
    request (not pushed to ``main``).

    Args:
        new_entry: JSON-serializable dict describing the paper. It should
            contain an ``'arxiv_id'`` key, which is used in the commit message.
        oauth_token: Gradio OAuth token of the logged-in user. Its ``.token``
            authenticates every Hub call.

    Returns:
        The pull-request URL on success. Otherwise a human-readable message
        starting with ``"Error"``; this function reports failures by return
        value and does not raise.
    """
    repo_id = 'IAMJB/paper-central-pr'
    filename = 'data.json'
    api = HfApi()
    token = oauth_token.token

    # Ensure the repository exists and already holds a data.json to append to.
    try:
        api.create_repo(repo_id=repo_id, token=token, repo_type='dataset', exist_ok=True)
        files = api.list_repo_files(repo_id, repo_type='dataset', token=token)
        if filename not in files:
            # Upload from memory. A fixed on-disk temp file would be shared by
            # concurrent users and left behind if the commit failed.
            init_op = CommitOperationAdd(
                path_in_repo=filename,
                path_or_fileobj=json.dumps([]).encode('utf-8'),
            )
            api.create_commit(
                repo_id=repo_id,
                operations=[init_op],
                commit_message="Initialize data.json",
                repo_type="dataset",
                token=token,
            )
    except Exception as e:
        return f"Error creating or accessing repository: {e}"

    # Fetch the current entries. Abort on failure: continuing with an empty list
    # would produce a PR that replaces every existing entry with just this one.
    try:
        local_filepath = hf_hub_download(repo_id=repo_id, filename=filename,
                                         repo_type='dataset', token=token)
        with open(local_filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except Exception as e:
        print(f"Error downloading existing data: {e}")
        return f"Error downloading existing data: {e}"

    if not isinstance(data, list):
        return f"Error: {filename} does not contain a JSON list."

    data.append(new_entry)

    try:
        commit = CommitOperationAdd(
            path_in_repo=filename,
            path_or_fileobj=json.dumps(data, indent=2).encode('utf-8'),
        )
        res = api.create_commit(
            repo_id=repo_id,
            operations=[commit],
            commit_message=f"Add new entry for arXiv ID {new_entry['arxiv_id']}",
            repo_type="dataset",
            create_pr=True,
            token=token,
        )
    except Exception as e:
        print(f"Error creating PR: {e}")
        return f"Error creating PR: {e}"
    return res.pr_url
def _is_missing(value):
    """Return True for None/NaN/NA-like scalars (pandas' markers for empty cells)."""
    # is_scalar guard: pd.isna on list-like cells returns an array, not a bool.
    return pd.api.types.is_scalar(value) and bool(pd.isna(value))


def _row_value(row, field):
    """Return the prefill value for ``field`` from DataFrame ``row``, or "" if absent/missing.

    ``field['column']`` (when given) is tried before ``field['name']``. This lets
    a widget whose name differs from its DataFrame column still be prefilled.
    """
    for key in (field.get('column', field['name']), field['name']):
        if key in row.index:
            value = row[key]
            return "" if _is_missing(value) else value
    return ""


def pr_paper_central_tab(paper_central_df):
    """Build the "PR Paper-central" tab and register its event handlers.

    A logged-in user enters an arXiv ID. If the ID is in ``paper_central_df``,
    the metadata fields are prefilled from the matching row. Submitting opens a
    PR on the Paper Central PR dataset via ``create_pr_in_hf_dataset``. When the
    paper is unknown or has no ``paper_page``, and the server defines the
    ``paper_space_pr_token`` environment variable, a "Create Paper Page" button
    is also offered.

    The function creates Gradio components and registers events, so it is meant
    to be called inside a ``gr.Blocks`` context.

    Args:
        paper_central_df: pandas DataFrame with an ``'arxiv_id'`` column and,
            optionally, ``'paper_page'`` plus one column per editable field.
    """
    with gr.Column():
        gr.Markdown("## PR Paper-central")

        # Hidden and not referenced by any event in this function.
        login_prompt = gr.Markdown("Please log in to proceed.", visible=False)

        arxiv_id_input = gr.Textbox(label="Enter arXiv ID")
        arxiv_id_button = gr.Button("Submit")

        # Status/error text for the lookup and for paper-page creation.
        message = gr.Markdown("", visible=False)

        create_paper_page_button = gr.Button(
            "Create Paper Page", visible=False,
            icon="https://huggingface.co/front/assets/huggingface_logo-noborder.svg")

        # Editable metadata. 'name' is the widget key, 'label' the UI label and the
        # optional 'column' the DataFrame column used for prefilling.
        # NOTE: create_pr's positional parameters must stay in sync with this list.
        fields = [
            {'name': 'github', 'label': 'GitHub URL'},
            {'name': 'conference_name', 'label': 'Conference Name'},
            # Named 'type_' so create_pr's parameter does not shadow the builtin;
            # the stored key is 'type' (see new_entry), so prefill from that column.
            {'name': 'type_', 'label': 'Type', 'column': 'type'},
            {'name': 'proceedings', 'label': 'Proceedings'},
        ]
        input_fields = {
            field['name']: gr.Textbox(label=field['label'], visible=False)
            for field in fields
        }
        field_components = [input_fields[field['name']] for field in fields]

        create_pr_button = gr.Button(
            "Create PR", visible=False,
            icon="https://huggingface.co/front/assets/huggingface_logo-noborder.svg")

        pr_message = gr.Markdown("", visible=False)
        loading_message = gr.Markdown("Creating PR, please wait...", visible=False)

        def _only_message(text):
            """Updates that show ``text`` and hide the fields, both buttons and pr_message."""
            # Fresh dicts per output: Gradio may consume/mutate update dicts.
            return ([gr.update(value=text, visible=True)]
                    + [gr.update(visible=False) for _ in fields]
                    + [gr.update(visible=False) for _ in range(3)])

        def check_login_and_handle_arxiv_id(arxiv_id, oauth_token: Optional[gr.OAuthToken]):
            """Look up ``arxiv_id`` and return updates for
            [message, *fields, create_pr_button, create_paper_page_button, pr_message].

            ``oauth_token`` is supplied by Gradio from the type hint; it is None
            when the user is not logged in.
            """
            if oauth_token is None:
                return _only_message("Please log in to proceed.")

            arxiv_id = (arxiv_id or "").strip()
            if not arxiv_id:
                return _only_message("Please enter an arXiv ID.")

            can_create_page = os.getenv('paper_space_pr_token') is not None
            matches = paper_central_df[paper_central_df['arxiv_id'] == arxiv_id]

            if matches.empty:
                message_update = gr.update(
                    value="arXiv ID not found. You can create a paper page.", visible=True)
                values = ["" for _ in fields]
                show_page_button = can_create_page
            else:
                row = matches.iloc[0]
                values = [_row_value(row, field) for field in fields]
                paper_page = row.get('paper_page', "")
                # NaN is truthy, so check for pandas' missing markers explicitly.
                if _is_missing(paper_page) or not paper_page:
                    message_update = gr.update(
                        value="Paper page not found. You can create one.", visible=True)
                    show_page_button = can_create_page
                else:
                    message_update = gr.update(value="", visible=False)
                    show_page_button = False

            return ([message_update]
                    + [gr.update(value=value, visible=True) for value in values]
                    + [gr.update(visible=True),              # create_pr_button
                       gr.update(visible=show_page_button),  # create_paper_page_button
                       gr.update(visible=False)])            # pr_message

        arxiv_id_button.click(
            fn=check_login_and_handle_arxiv_id,
            inputs=[arxiv_id_input],
            outputs=[message] + field_components + [create_pr_button,
                                                    create_paper_page_button,
                                                    pr_message],
            api_name=False
        )

        def create_pr(arxiv_id, github, conference_name, type_, proceedings,
                      oauth_token: Optional[gr.OAuthToken] = None):
            """Submit the edited fields as a PR and return an update for ``pr_message``."""
            if oauth_token is None:
                return gr.update(value="Please log in first.", visible=True)
            arxiv_id = (arxiv_id or "").strip()
            if not arxiv_id:
                return gr.update(value="Please enter an arXiv ID.", visible=True)

            new_entry = {
                'arxiv_id': arxiv_id,
                'github': github,
                'conference_name': conference_name,
                'type': type_,
                'proceedings': proceedings
            }
            result = create_pr_in_hf_dataset(new_entry, oauth_token)
            # create_pr_in_hf_dataset reports failures as "Error..." strings, not exceptions.
            if not result or result.startswith("Error"):
                return gr.update(value=result or "Error creating PR.", visible=True)
            return gr.update(value=f"PR created: {result}", visible=True)

        create_pr_button.click(
            fn=lambda: gr.update(visible=True),  # show loading message
            inputs=[],
            outputs=[loading_message],
            api_name=False
        ).then(
            fn=create_pr,
            inputs=[arxiv_id_input] + field_components,
            outputs=[pr_message],
            api_name=False
        ).then(
            fn=lambda: gr.update(visible=False),  # hide loading message
            inputs=[],
            outputs=[loading_message],
            api_name=False
        )

        def create_paper_page(arxiv_id, oauth_token: Optional[gr.OAuthToken] = None):
            """Index and submit ``arxiv_id`` as a Hub paper page; return an update for ``message``.

            The calls use the server's ``paper_space_pr_token``, so they are gated on login.
            """
            index_url = "https://huggingface.co/api/papers/index"
            submit_url = "https://huggingface.co/api/papers/submit"
            timeout_s = 60  # avoid hanging the worker forever on a stalled request

            if oauth_token is None:
                return gr.update(value="Please log in first.", visible=True)
            arxiv_id = (arxiv_id or "").strip()
            if not arxiv_id:
                return gr.update(value="Please enter an arXiv ID.", visible=True)

            access_token = os.getenv('paper_space_pr_token')
            if not access_token:
                return gr.update(value="Server error: Access token not found.", visible=True)

            headers = {
                "Authorization": f"Bearer {access_token}",
                "Content-Type": "application/json"
            }
            try:
                response_index = requests.post(index_url, json={"arxivId": arxiv_id},
                                               headers=headers, timeout=timeout_s)
                if response_index.status_code != 200:
                    return gr.update(
                        value=f"Failed to index paper: {response_index.status_code}, {response_index.text}",
                        visible=True)

                # NOTE(review): assumes the paper id equals the arXiv id — confirm with the API.
                payload_submit = {
                    "paperId": arxiv_id,
                    "comment": "",
                    "mediaUrls": []
                }
                response_submit = requests.post(submit_url, json=payload_submit,
                                                headers=headers, timeout=timeout_s)
            except requests.RequestException as e:
                return gr.update(value=f"Failed to reach the Hugging Face API: {e}", visible=True)

            if response_submit.status_code != 200:
                return gr.update(
                    value=f"Failed to submit paper: {response_submit.status_code}, {response_submit.text}",
                    visible=True)
            return gr.update(value="Paper page created successfully.", visible=True)

        create_paper_page_button.click(
            fn=create_paper_page,
            inputs=[arxiv_id_input],
            outputs=[message],
            api_name=False
        )