# Hugging Face Space: shaocongma (status: Running)
# Commit 94dc00e: Add a generator wrapper using configuration file. Edit the
# logic of searching references. Add Gradio UI for testing Knowledge database.
'''
This script is only used for the service-side host.
'''
import boto3 | |
import os, time | |
from wrapper import generator_wrapper | |
from sqlalchemy import create_engine, Table, MetaData, update, select | |
from sqlalchemy.orm import sessionmaker | |
from sqlalchemy import inspect | |
# Runtime configuration, all read from the environment.
QUEUE_URL = os.getenv('QUEUE_URL')                          # SQS queue polled for task messages
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')          # AWS credentials for SQS/S3
AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_ACCESS_KEY')
BUCKET_NAME = os.getenv('BUCKET_NAME')                      # S3 bucket holding configs and results
DB_STRING = os.getenv('DATABASE_STRING')                    # SQLAlchemy database URL

# Shared database engine and session factory for the whole module.
ENGINE = create_engine(DB_STRING)
SESSION = sessionmaker(bind=ENGINE)
#######################################################################################################################
# Amazon SQS Handler
#######################################################################################################################
def get_sqs_client():
    """Build a boto3 SQS client for us-east-2 with the configured credentials."""
    return boto3.client(
        'sqs',
        region_name="us-east-2",
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    )
def receive_message():
    """Poll the task queue once.

    Returns:
        tuple: the raw SQS response dict, and the receipt handle of the
        first message (or None when the response carried no messages).
    """
    response = get_sqs_client().receive_message(QueueUrl=QUEUE_URL)
    messages = response.get('Messages')
    if messages is None:
        receipt_handle = None
    else:
        receipt_handle = messages[0]['ReceiptHandle']
    return response, receipt_handle
def delete_message(receipt_handle):
    """Delete a processed message from the queue and return the SQS response."""
    client = get_sqs_client()
    return client.delete_message(QueueUrl=QUEUE_URL, ReceiptHandle=receipt_handle)
#######################################################################################################################
# AWS S3 Handler
#######################################################################################################################
def get_s3_client():
    """Create an S3 service resource plus a handle on the working bucket.

    Returns:
        tuple: (boto3 S3 resource, Bucket object for BUCKET_NAME).
    """
    session = boto3.Session(
        aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
        aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
    )
    resource = session.resource('s3')
    return resource, resource.Bucket(BUCKET_NAME)
def upload_file(file_name, target_name=None):
    """Upload a local file to the S3 bucket.

    Args:
        file_name: Local path of the file to upload.
        target_name: S3 key to store it under; defaults to `file_name`.
    """
    client, _ = get_s3_client()
    key = file_name if target_name is None else target_name
    client.meta.client.upload_file(Filename=file_name, Bucket=BUCKET_NAME, Key=key)
    print(f"The file {file_name} has been uploaded!")
def download_file(file_name):
    """Download the object `file_name` from the bucket.

    The object is written into the current working directory under its
    basename (directory components of the S3 key are stripped).
    """
    client, _ = get_s3_client()
    local_name = os.path.basename(file_name)
    client.meta.client.download_file(Bucket=BUCKET_NAME, Key=file_name, Filename=local_name)
    print(f"The file {file_name} has been downloaded!")
#######################################################################################################################
# AWS SQL Handler
#######################################################################################################################
def modify_status(task_id, new_status):
    """Set the `status` column of row `task_id` in the `task` table.

    Status codes used by this service: 0 - pending (default), 1 - running,
    2 - completed, 3 - failed.  Silently does nothing when no row with the
    given task id exists.

    Args:
        task_id: Primary identifier of the task row to update.
        new_status: Integer status code to store.
    """
    # Fix: the original opened an ORM session it never used (and a pointless
    # alias of task_id); all work goes through a single Core connection.
    metadata = MetaData()
    task_table = Table('task', metadata, autoload_with=ENGINE)
    with ENGINE.connect() as connection:
        # Verify the row exists before issuing the UPDATE.
        row = connection.execute(
            select(task_table).where(task_table.c.task_id == task_id)
        ).fetchone()
        if row:
            connection.execute(
                update(task_table)
                .where(task_table.c.task_id == task_id)
                .values(status=new_status)
            )
            connection.commit()
#######################################################################################################################
# Pipeline
#######################################################################################################################
def pipeline(message_count=0, query_interval=10):
    """Process at most one task message from the SQS queue.

    Workflow: receive a message whose body is the S3 key of a task
    configuration file, download it, mark the task running, delete the
    message, run the generator, upload the resulting ZIP back to S3 next to
    the configuration file, and mark the task completed.  If the generator
    raises, the task is marked failed (status 3) and the error re-raised.

    Args:
        message_count: Counter echoed in the log banner only.
        query_interval: Seconds to sleep when the queue is empty.
    """
    # status: 0 - pending (default), 1 - running, 2 - completed, 3 - failed
    msg, handle = receive_message()
    if handle is None:
        print("No message in SQS. ")
        time.sleep(query_interval)
        return
    print("===============================================================================================")
    print(f"MESSAGE COUNT: {message_count}")
    print("===============================================================================================")
    # The message body is the S3 key of the task's configuration file; the
    # task id is that file's basename without extension.
    config_s3_path = msg['Messages'][0]['Body']
    config_s3_dir = os.path.dirname(config_s3_path)
    config_local_path = os.path.basename(config_s3_path)
    task_id, _ = os.path.splitext(config_local_path)
    print("Initializing ...")
    print("Configuration file on S3: ", config_s3_path)
    print("Configuration file on S3 (Directory): ", config_s3_dir)
    print("Local file path: ", config_local_path)
    print("Task id: ", task_id)
    print(f"Success in receiving message: {msg}")
    print(f"Configuration file path: {config_s3_path}")
    # Download the config and flag the task running before removing the
    # message from the queue.
    download_file(config_s3_path)
    modify_status(task_id, 1)
    delete_message(handle)
    print(f"Success in the initialization. Message deleted.")
    print("Running ...")
    try:
        zip_path = generator_wrapper(config_local_path)
    except Exception:
        # Fix: previously a failed generation left the task stuck at
        # status 1 ("running") forever; record the failure, then re-raise.
        modify_status(task_id, 3)
        raise
    # Upload the generated archive next to the configuration file on S3
    # (forward slashes so Windows path joins still form a valid S3 key).
    upload_to = os.path.join(config_s3_dir, zip_path).replace("\\", "/")
    print("Local file path (ZIP): ", zip_path)
    print("Upload to S3: ", upload_to)
    upload_file(zip_path, upload_to)
    modify_status(task_id, 2)
    print(f"Success in generating the paper.")
    print("Task completed.")
def initialize_everything():
    """Reset service-side state — placeholder, not yet implemented.

    Intended to clear the S3 bucket and the SQS queue.
    """
    # TODO: Clear S3
    # TODO: Clear SQS
    return None
if __name__ == "__main__":
    # NOTE(review): runs a single polling iteration per invocation —
    # presumably an outer scheduler restarts the script; confirm.
    pipeline()