File size: 3,518 Bytes
dcc13e8
 
8ddbef4
 
dcc13e8
8ddbef4
 
 
 
 
dcc13e8
8ddbef4
 
dcc13e8
 
8ddbef4
dcc13e8
8ddbef4
 
dcc13e8
8ddbef4
 
 
 
 
 
dcc13e8
8ddbef4
 
dcc13e8
 
8ddbef4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210df67
8ddbef4
 
 
 
210df67
a6df729
210df67
 
 
a6df729
210df67
a6df729
210df67
 
 
 
 
 
 
 
 
a6df729
210df67
 
 
 
 
 
a6df729
210df67
 
a6df729
210df67
 
 
a6df729
210df67
8ddbef4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import re
import tempfile

import pandas as pd
import streamlit as st
import torch
from datasets import Dataset
from langchain_community.document_loaders import UnstructuredPDFLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, pipeline

# Set up page (set_page_config is issued before any other Streamlit call)
st.set_page_config(
    page_title="Tweet Style Cloning",
    page_icon="🐦",
    layout="centered"
)
st.title("🐦 Clone Tweet Style from PDF")

# Step 1: Upload PDF
# Restrict the picker to PDFs: the upload is handed straight to a PDF loader,
# so any other file type would only fail later during text extraction.
uploaded_file = st.file_uploader("Upload a PDF with tweets", type=["pdf"])

if uploaded_file is not None:
    # Step 2: Extract text from PDF
    def load_pdf_text(file_path):
        """Load the PDF at *file_path* and return its text as one string.

        The content of every document returned by the loader is joined
        with a single space.
        """
        loaded_docs = UnstructuredPDFLoader(file_path).load()
        return " ".join(part.page_content for part in loaded_docs)

    # Save the uploaded PDF to a uniquely named temporary file. A fixed file
    # name in the working directory would be shared by every session of the
    # app, so two simultaneous uploads could overwrite each other.
    # delete=False lets the loader reopen the file by path after it is closed.
    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp_pdf:
        tmp_pdf.write(uploaded_file.getbuffer())
        pdf_path = tmp_pdf.name

    # Extract text from PDF, always removing the temporary copy afterwards
    try:
        extracted_text = load_pdf_text(pdf_path)
    finally:
        os.remove(pdf_path)

    # Step 3: Preprocess text to separate each tweet (assuming tweets end with newline)
    stripped_lines = (line.strip() for line in re.split(r'\n+', extracted_text))
    tweets = [line for line in stripped_lines if line]

    # Nothing to train on: tell the user and halt this script run instead of
    # letting the fine-tuning step fail on an empty dataset.
    if not tweets:
        st.error("No tweets could be extracted from the uploaded PDF.")
        st.stop()

    # Display a few sample tweets for verification
    st.write("Sample Tweets Extracted:")
    st.write(tweets[:5])

    # Step 4: Fine-tune a model on the extracted tweets
    def fine_tune_model(tweets):
        """Fine-tune GPT-2 on *tweets* and return the trained model and tokenizer.

        Args:
            tweets: Non-empty list of tweet strings used as training text.

        Returns:
            A ``(model, tokenizer)`` tuple. Both are also saved to the
            ``fine_tuned_tweet_model`` directory.

        Raises:
            ValueError: If *tweets* is empty.
        """
        # Function-scope import: the file-level transformers import does not
        # include the collator.
        from transformers import DataCollatorForLanguageModeling

        if not tweets:
            raise ValueError("Cannot fine-tune on an empty list of tweets.")

        # Convert tweets to a DataFrame and Dataset
        df = pd.DataFrame(tweets, columns=["text"])
        tweet_dataset = Dataset.from_pandas(df)

        # Load model and tokenizer
        model_name = "gpt2"  # Replace with a suitable model if needed
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)

        # GPT-2's tokenizer has no pad token by default; reuse EOS so that
        # padding to a fixed length works.
        tokenizer.pad_token = tokenizer.eos_token

        # Tokenize the dataset
        def tokenize_function(examples):
            return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

        # Drop the raw "text" column so only tensor-convertible features
        # reach the data collator.
        tokenized_tweets = tweet_dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=tweet_dataset.column_names,
        )

        # A causal LM only returns a loss when the batch contains "labels".
        # With mlm=False the collator copies input_ids into labels and sets
        # padding positions to -100 so they are ignored by the loss.
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

        # Training arguments
        training_args = TrainingArguments(
            output_dir="./fine_tuned_tweet_model",
            per_device_train_batch_size=4,
            num_train_epochs=3,
            save_steps=10_000,
            save_total_limit=1,
            logging_dir='./logs',
        )

        # Initialize the Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_tweets,
            data_collator=data_collator,
        )

        # Fine-tune the model
        trainer.train()

        # Save the fine-tuned model
        model.save_pretrained("fine_tuned_tweet_model")
        tokenizer.save_pretrained("fine_tuned_tweet_model")

        # The returned model is used for generation, so leave training mode
        # (disables dropout).
        model.eval()

        return model, tokenizer

    # Trigger fine-tuning and notify user.
    # Streamlit re-executes this script on every widget interaction, including
    # each edit of the prompt box below. The trained model is therefore kept in
    # st.session_state and only rebuilt when the extracted tweets change;
    # otherwise every prompt would trigger a full retraining run.
    cached_run = st.session_state.get("tweet_style_model")
    if cached_run is None or cached_run["tweets"] != tweets:
        with st.spinner("Fine-tuning model..."):
            model, tokenizer = fine_tune_model(tweets)
        cached_run = {"tweets": tweets, "model": model, "tokenizer": tokenizer}
        st.session_state["tweet_style_model"] = cached_run
    model = cached_run["model"]
    tokenizer = cached_run["tokenizer"]
    st.success("Model fine-tuned successfully!")

    # Step 5: Set up text generation
    tweet_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

    # Generate a new tweet based on user input
    prompt = st.text_input("Enter a prompt for a new tweet in the same style:")
    if prompt:
        with st.spinner("Generating tweet..."):
            # max_new_tokens bounds only the generated continuation. The
            # previous max_length=50 also counted the prompt's tokens, so a
            # long prompt left no room for any new text.
            generated_tweet = tweet_generator(prompt, max_new_tokens=50, num_return_sequences=1)
            st.write("Generated Tweet:")
            st.write(generated_tweet[0]["generated_text"])