Fred808 committed on
Commit 186de06 · verified · 1 Parent(s): 51c11a7

Update app.py

Files changed (1)
  1. app.py +202 -95
app.py CHANGED
@@ -1,98 +1,205 @@
- # Install necessary libraries
- # pip install transformers datasets torch
-
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
- from transformers import Trainer, TrainingArguments
- from datasets import Dataset
-
- # Step 1: Load the pre-trained GPT-2 model and tokenizer
- model_name = "gpt2"  # You can use any GPT model, GPT-3, or other variants if you want a bigger model
- model = GPT2LMHeadModel.from_pretrained(model_name)
- tokenizer = GPT2Tokenizer.from_pretrained(model_name)
-
- # Set padding token as GPT-2 doesn't have one by default
- tokenizer.pad_token = tokenizer.eos_token
-
- # Step 2: Prepare your training data (Instagram algorithm and feature usage)
- training_data = [
-     {
-         "input": "How can I improve engagement on Instagram?",
-         "output": "Engagement can be improved by posting at optimal times, using 20-30 relevant hashtags, and responding to comments quickly. Consider using reels for higher visibility."
-     },
-     {
-         "input": "What are the best times to post on Instagram?",
-         "output": "The best times to post on Instagram depend on your audience's time zone. Typically, posting during peak activity times such as early morning or late evening can lead to better engagement."
-     },
-     {
-         "input": "How do I use Instagram Insights?",
-         "output": "Go to your profile, tap the menu, and select 'Insights.' You can view metrics like reach, impressions, and engagement."
-     },
-     {
-         "input": "What is the best way to use hashtags on Instagram?",
-         "output": "Use a mix of trending, niche, and brand-specific hashtags. Aim for around 20-30 relevant hashtags per post. Research the most effective ones for your target audience."
-     },
-     {
-         "input": "How can I use Instagram Stories to grow my account?",
-         "output": "Instagram Stories can be used to engage your followers by sharing behind-the-scenes content, polls, Q&As, and other interactive elements. Consistency and engaging content are key."
-     },
- ]
-
- # Step 3: Process the data into a format suitable for training
- def process_data(examples):
-     # Concatenate input and output to form the training sequence
-     return tokenizer(examples['input'] + tokenizer.eos_token + examples['output'], truncation=True, padding="max_length", max_length=128)
-
- # Convert the training data into a dataset
- dataset = Dataset.from_dict(training_data)
- dataset = dataset.map(process_data, batched=True)
-
- # Step 4: Split the dataset into training and validation sets
- train_dataset = dataset.train_test_split(test_size=0.1)["train"]
- val_dataset = dataset.train_test_split(test_size=0.1)["test"]
-
- # Step 5: Define the training arguments
- training_args = TrainingArguments(
-     output_dir="./gpt2-instagram-model",  # Directory to save the model
-     evaluation_strategy="epoch",          # Evaluate at the end of each epoch
-     learning_rate=5e-5,                   # Learning rate for fine-tuning
-     per_device_train_batch_size=4,        # Batch size for training
-     per_device_eval_batch_size=4,         # Batch size for evaluation
-     num_train_epochs=3,                   # Number of training epochs
-     weight_decay=0.01,                    # Weight decay for regularization
-     logging_dir='./logs',                 # Log directory
-     logging_steps=200,                    # Log every 200 steps
  )
-
- # Step 6: Initialize the Trainer
- trainer = Trainer(
-     model=model,                  # The model we are training
-     args=training_args,           # Training arguments
-     train_dataset=train_dataset,  # Training dataset
-     eval_dataset=val_dataset,     # Validation dataset
- )
-
- # Step 7: Train the model
- trainer.train()
-
- # Step 8: Evaluate the model after training
- results = trainer.evaluate()
- print("Evaluation Results:", results)
-
- # Step 9: Save the model and tokenizer
- model.save_pretrained("./gpt2-instagram-model")
- tokenizer.save_pretrained("./gpt2-instagram-model")
-
- # Step 10: Use the trained model to generate responses
- def generate_response(input_text):
-     # Encode the input text and generate a response
-     inputs = tokenizer.encode(input_text, return_tensors="pt")
-     output = model.generate(inputs, max_length=100, num_return_sequences=1, no_repeat_ngram_size=2)
-
-     # Decode and return the response
-     response = tokenizer.decode(output[0], skip_special_tokens=True)
-     return response
-
- # Example: Generate a response
- input_text = "How can I improve engagement on Instagram?"
- response = generate_response(input_text)
- print("Generated Response:", response)
 
+ import re
+ import json
+ import numpy as np
+ import faiss
+ from flask import Flask, request, jsonify
+ from transformers import (
+     pipeline,
+     AutoModelForSequenceClassification,
+     AutoTokenizer,
+     AutoModelForSeq2SeqLM,
+     AutoModelForCausalLM,
+     T5Tokenizer,
+     T5ForConditionalGeneration,
  )
+ from sentence_transformers import SentenceTransformer
+ from bertopic import BERTopic
+ from datasets import load_dataset
+
+ # Preprocessing function
+ def preprocess_text(text):
+     """
+     Cleans and tokenizes text.
+     """
+     text = re.sub(r"http\S+|www\S+|https\S+", "", text, flags=re.MULTILINE)  # Remove URLs
+     text = re.sub(r"\s+", " ", text).strip()  # Remove extra spaces
+     text = re.sub(r"[^\w\s]", "", text)  # Remove punctuation
+     return text.lower()
+
+
+ # Content Classification Model
+ class ContentClassifier:
+     def __init__(self, model_name="bert-base-uncased"):
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
+         self.pipeline = pipeline("text-classification", model=self.model, tokenizer=self.tokenizer)
+
+     def classify(self, text):
+         """
+         Classifies text into predefined categories.
+         """
+         result = self.pipeline(text)
+         return result
+
+
+ # Relevance Detection Model
+ class RelevanceDetector:
+     def __init__(self, model_name="bert-base-uncased"):
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
+         self.pipeline = pipeline("text-classification", model=self.model, tokenizer=self.tokenizer)
+
+     def detect_relevance(self, text, threshold=0.5):
+         """
+         Detects whether a text is relevant to a specific domain.
+         """
+         result = self.pipeline(text)
+         return result[0]["label"] == "RELEVANT" and result[0]["score"] > threshold
+
+
+ # Topic Extraction Model using BERTopic
+ class TopicExtractor:
+     def __init__(self):
+         self.model = BERTopic()
+
+     def extract_topics(self, documents):
+         """
+         Extracts topics from a list of documents.
+         """
+         topics, probs = self.model.fit_transform(documents)
+         return self.model.get_topic_info()
+
+
+ # Summarization Model
+ class Summarizer:
+     def __init__(self, model_name="t5-small"):
+         self.tokenizer = T5Tokenizer.from_pretrained(model_name)
+         self.model = T5ForConditionalGeneration.from_pretrained(model_name)
+
+     def summarize(self, text, max_length=100):
+         """
+         Summarizes a given text.
+         """
+         inputs = self.tokenizer.encode("summarize: " + text, return_tensors="pt", max_length=512, truncation=True)
+         summary_ids = self.model.generate(inputs, max_length=max_length, min_length=25, length_penalty=2.0, num_beams=4)
+         summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+         return summary
+
+
+ # Search and Recommendation Model using FAISS
+ class SearchEngine:
+     def __init__(self, embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
+         self.model = SentenceTransformer(embedding_model)
+         self.index = None
+         self.documents = []
+
+     def build_index(self, docs):
+         """
+         Builds a FAISS index for document retrieval.
+         """
+         self.documents = docs
+         embeddings = self.model.encode(docs, convert_to_tensor=True, show_progress_bar=True)
+         self.index = faiss.IndexFlatL2(embeddings.shape[1])
+         self.index.add(embeddings.cpu().detach().numpy())
+
+     def search(self, query, top_k=5):
+         """
+         Searches the index for the top_k most relevant documents.
+         """
+         query_embedding = self.model.encode(query, convert_to_tensor=True)
+         distances, indices = self.index.search(query_embedding.cpu().detach().numpy().reshape(1, -1), top_k)
+         return [(self.documents[i], distances[0][i]) for i in indices[0]]
+
+
+ # Conversational Model using GPT-2
+ class Chatbot:
+     def __init__(self, model_name="gpt2"):
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.model = AutoModelForCausalLM.from_pretrained(model_name)
+
+     def generate_response(self, prompt, max_length=50):
+         """
+         Generates a response to a user query using GPT-2.
+         """
+         inputs = self.tokenizer.encode(prompt, return_tensors="pt")
+         outputs = self.model.generate(inputs, max_length=max_length, num_return_sequences=1)
+         response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return response
+
+
+ # Flask API for Chatbot Integration
+ app = Flask(__name__)
+
+ # Initialize models
+ classifier = ContentClassifier()
+ relevance_detector = RelevanceDetector()
+ summarizer = Summarizer()
+ search_engine = SearchEngine()
+ topic_extractor = TopicExtractor()
+ chatbot = Chatbot()
+
+ # Load PleIAs/YouTube-Commons dataset
+ def load_youtube_data():
+     dataset = load_dataset("PleIAs/YouTube-Commons")
+     return dataset["train"]["text"]  # Adjust based on dataset structure
+
+
+ # Preprocess and build search index
+ youtube_data = load_youtube_data()
+ search_engine.build_index(youtube_data)
+
+ # API Endpoints
+ @app.route("/classify", methods=["POST"])
+ def classify():
+     text = request.json.get("text", "")
+     if not text:
+         return jsonify({"error": "No text provided"}), 400
+     result = classifier.classify(text)
+     return jsonify(result)
+
+
+ @app.route("/relevance", methods=["POST"])
+ def relevance():
+     text = request.json.get("text", "")
+     if not text:
+         return jsonify({"error": "No text provided"}), 400
+     relevant = relevance_detector.detect_relevance(text)
+     return jsonify({"relevant": relevant})
+
+
+ @app.route("/summarize", methods=["POST"])
+ def summarize():
+     text = request.json.get("text", "")
+     if not text:
+         return jsonify({"error": "No text provided"}), 400
+     summary = summarizer.summarize(text)
+     return jsonify({"summary": summary})
+
+
+ @app.route("/search", methods=["POST"])
+ def search():
+     query = request.json.get("query", "")
+     if not query:
+         return jsonify({"error": "No query provided"}), 400
+     results = search_engine.search(query)
+     return jsonify({"results": results})
+
+
+ @app.route("/topics", methods=["POST"])
+ def topics():
+     result = topic_extractor.extract_topics(youtube_data)
+     return jsonify({"topics": result.to_dict()})
+
+
+ @app.route("/chat", methods=["POST"])
+ def chat():
+     prompt = request.json.get("prompt", "")
+     if not prompt:
+         return jsonify({"error": "No prompt provided"}), 400
+     response = chatbot.generate_response(prompt)
+     return jsonify({"response": response})
 
+
+ # Start the Flask API
+ if __name__ == "__main__":
+     app.run(debug=True)
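
A minimal client-side sketch for exercising the new endpoints (not part of the commit): it assumes app.py is already running locally on Flask's default http://127.0.0.1:5000 and uses the requests library, which app.py itself does not import.

# Hypothetical smoke test for the Flask API above; run `python app.py` first.
import requests

BASE_URL = "http://127.0.0.1:5000"  # Flask's default host/port with app.run(debug=True)

# /chat expects {"prompt": ...} and returns {"response": ...}
chat = requests.post(f"{BASE_URL}/chat", json={"prompt": "How can I improve engagement on Instagram?"})
print(chat.json())

# /summarize expects {"text": ...} and returns {"summary": ...}
summary = requests.post(f"{BASE_URL}/summarize", json={"text": "Some long transcript text to condense ..."})
print(summary.json())

# /search expects {"query": ...} and returns {"results": [...]}
hits = requests.post(f"{BASE_URL}/search", json={"query": "video editing tips"})
print(hits.json())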