from flask import Flask, jsonify, request
import requests
import redis
import json
from flask_cors import CORS
import os

# Point the Hugging Face cache at a writable path before transformers is imported,
# otherwise the default cache location is used.
os.environ["TRANSFORMERS_CACHE"] = "/code/.cache"

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from history import load_dataset, get_unique_next_words_from_dataset
from dotenv import load_dotenv
from typing import List, Dict, Optional, Union
import logging
from most_repeted_sentences import sentences_name, get_most_repeated_sentences, save_most_repeated_sentences

load_dotenv()

app = Flask(__name__)
CORS(app)

# Set up logging
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)

# Pixabay API setup (query parameters are passed separately via `params`)
PIXABAY_URL = "https://pixabay.com/api/"
PIXABAY_API_KEY = os.getenv("API_kEY")

# Set up Redis
redis_client = redis.Redis(
    host='redis-18594.c301.ap-south-1-1.ec2.redns.redis-cloud.com',
    port=18594,
    decode_responses=True,
    username="default",
    password=os.getenv("REDIS_PASSWORD")
)
print(redis_client)

# Load the model and tokenizer once when the app starts
model = GPT2LMHeadModel.from_pretrained("gpt2").to("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Global variables
predicted_words = []
append_list = []
global_count = 0
default_predicted_words = [
    'i', 'what', 'hello', 'where', 'who', 'how', 'can', 'is', 'are', 'could',
    'would', 'may', 'can', 'please', 'will', 'shall', 'did', 'have', 'has', 'had',
    'am', 'were', 'was', 'should', 'might', 'must', 'please', 'you', 'he', 'she',
    'they', 'it', 'this', 'that', 'these', 'those', 'let', 'we', 'my', 'your',
    'his', 'her', 'their', 'our', 'the', 'there', 'come', 'go', 'bring', 'take',
    'give', 'help', 'want', 'need', 'eat', 'drink', 'sleep', 'play', 'run', 'walk',
    'talk', 'call', 'find', 'make', 'see', 'get', 'know'
]


def generate_predicted_words(input_text, index=0):
    # Load the dataset and collect the words that followed input_text in history
    dataset_name = "dataset.txt"
    dataset = load_dataset(dataset_name)
    history_next_text = get_unique_next_words_from_dataset(input_text, dataset)

    # Tokenize input
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    # Forward pass through the model
    with torch.no_grad():
        outputs = model(**inputs, return_dict=True)
        logits = outputs.logits

    # Get the logits for the last token
    last_token_logits = logits[:, -1, :]
    probabilities = torch.softmax(last_token_logits, dim=-1)

    # Get the top 50 most probable next tokens
    top_50_probs, top_50_indices = torch.topk(probabilities, 50)
    top_50_tokens = [tokenizer.decode([idx], clean_up_tokenization_spaces=False) for idx in top_50_indices[0]]

    # Drop single characters, articles, and other filler tokens
    words = []
    removable_words = [' (', ' a', "'s", ' "', ' -', ' as', " '", "the", " the",
                       "an", " an", "<|endoftext|>", '’d', '’m', '’ll', 't’s']
    for token in top_50_tokens:
        if len(token) != 1 and token not in removable_words:
            words.append(token.strip().lower())

    # Dataset-based suggestions come first, followed by the filtered GPT-2 suggestions
    return history_next_text + words
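
# Rough illustration of what generate_predicted_words() returns when the
# /api/guu route below calls it (values are hypothetical; the real output
# depends on dataset.txt and the GPT-2 weights):
#
#   generate_predicted_words("how are")
#   # -> ['you', 'they', 'things', ...]  # history matches first, then filtered GPT-2 tokens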
# Fetch from Pixabay
def fetch_images_from_pixabay(query: str) -> dict:
    # print("fetching images for query...", query)
    response = requests.get(PIXABAY_URL, params={
        "key": PIXABAY_API_KEY,
        "q": query,
        "image_type": "all",
        "per_page": "3"
    })
    # print("response from Pixabay ====>", response.json())
    if response.status_code != 200:
        return {"error": "Failed to fetch data from Pixabay"}
    return response.json()


# Fetch images API
@app.route('/api/images', methods=['GET'])
def get_images():
    query = request.args.get('query')
    correspond_id = request.args.get('id')
    print("query id:", correspond_id)
    print("query:", query)
    if not query:
        return jsonify({"error": "Query parameter is required"}), 400

    # Check Redis cache for images
    cached_images = redis_client.get('image_cache')
    # print("cached images from Redis:", cached_images)
    if cached_images:
        cached_images = json.loads(cached_images)  # Convert JSON string to dictionary
        # print("cached_img", cached_images)
        for i in cached_images['hits']:
            # print("cached query_id ------------>", i.get('query_id'))
            # Compare the id stored with the cached hit against the current query id
            if i.get('query_id') == correspond_id:
                print("Fetching from cache-------------->", i['previewURL'])
                return jsonify(i['previewURL'])

    # Fetch from Pixabay if not in cache
    data = fetch_images_from_pixabay(query)
    if "error" in data:
        return jsonify(data), 500

    # Tag every new hit with the id it was fetched for
    for i in data['hits']:
        i['query_id'] = correspond_id
        # print("query_id inside hit:", i['query_id'])

    # Merge previously cached hits with the newly fetched ones
    if cached_images:
        data['hits'] = cached_images['hits'] + data['hits']
        data['total'] = cached_images['total'] + data['total']

    # Cache the result in Redis for 24 hours
    redis_client.setex('image_cache', 86400, json.dumps(data))

    # Pop once so the printed URL and the returned URL refer to the same hit
    preview_url = data['hits'].pop()['previewURL']
    print("image from Pixabay-------------->", preview_url)
    return jsonify(preview_url)


@app.route('/api/display_words', methods=['GET'])
def get_display_words():
    try:
        count = int(request.args.get('count', global_count))  # Default to global_count if 'count' is not provided
        print(type(count))
    except ValueError:
        return jsonify({"error": "Invalid count value"}), 400

    print("Count:", count)
    start_index = 9 * count
    end_index = start_index + 9
    print("Start index:", start_index)
    print("End index:", end_index)

    # Reset if out of bounds
    if start_index >= len(default_predicted_words):
        count = 0
        start_index = 0
        end_index = 9

    display_words = default_predicted_words[start_index:end_index]
    print("Display words:", display_words)
    return jsonify(display_words)


# @app.route('/api/scenerio', methods=['POST'])
# # @app.route('/api/select_location', methods=['GET'])
# def scenerio():
#     # Get the query parameter from the URL, e.g., /api/select_location?place=home
#     place = request.args.get('place')
#     if place == "home":
#         display_words = default_predicted_words[start_index:end_index]
#         return jsonify(display_words)


# @app.route('/api/huu', methods=['GET'])
# def fetch_most_repeated_sentences():  # Ensure the function name is unique
#     try:
#         with open('most_repeated_sentences.txt', 'r') as file:
#             # Read the first 5 lines
#             lines = []
#             for _ in range(5):
#                 text = file.readline().strip().split(":")[0]
#                 print(text)
#                 lines.append(text)
#             # lines = [file.readline().strip().split(':')[0] for _ in range(5)]
#         return jsonify(lines), 200  # Return the lines as JSON with a 200 OK status
#     except FileNotFoundError:
#         return jsonify({"error": "File not found."}), 404  # Handle file not found error
#     except Exception as e:
#         return jsonify({"error": str(e)}), 500  # Handle other potential errors
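
# Note: the /api/most_repeated_sentence route below expects
# most_repeated_sentences.txt to contain one "sentence: count" entry per line
# (that is what the rsplit(':', 1) / int() parsing assumes), for example:
#
#   i want to eat: 4
#   where are you: 2
#
# The example lines above are illustrative only; the real file is written by
# save_most_repeated_sentences() in the /api/guu reset branch.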
@app.route('/api/most_repeated_sentence', methods=['GET'])
def fetch_most_repeated_sentences():
    try:
        sentences = []
        with open('most_repeated_sentences.txt', 'r') as file:
            for line in file:
                line = line.strip()
                if ':' in line:  # Check if line contains the separator
                    try:
                        sentence, count = line.rsplit(':', 1)  # Split from the right side
                        count = int(count.strip())
                        sentences.append((sentence.strip(), count))
                    except (ValueError, IndexError):
                        continue  # Skip invalid lines

        # Sort sentences by count in descending order
        sorted_sentences = sorted(sentences, key=lambda x: x[1], reverse=True)

        # Get the top 5 sentences only
        top_5_sentences = [sentence[0] for sentence in sorted_sentences[:5]]
        return jsonify(top_5_sentences)

    except FileNotFoundError:
        return jsonify({"error": "File not found"}), 404
    except Exception as e:
        logger.error(f"Error in fetch_most_repeated_sentences: {e}")
        return jsonify({"error": "Internal server error"}), 500


@app.route('/api/guu', methods=['POST'])
def predict_words():
    global predicted_words, append_list, global_count
    try:
        data = request.get_json()
        print("Received data:", data)

        if not isinstance(data, dict):
            return jsonify({'error': 'Invalid JSON format'}), 400

        input_text = data.get('item', '').strip()  # Ensure we are checking the stripped input

        # # Handle case when input_text is "1"
        # if input_text == "1":
        #     print("Resetting append_list")
        #     append_list = []  # Reset the append list
        #     return jsonify(default_predicted_words[:9])  # Return the default words

        # Handle reset request
        if input_text == "1":
            with open('dataset.txt', 'a') as file:
                file.write(' '.join(append_list) + '\n')
            append_list = []
            global_count = 0
            sentence = sentences_name('dataset.txt')
            repeated_sentences = get_most_repeated_sentences(sentence)
            print("Most repeated sentences:", repeated_sentences)
            save_most_repeated_sentences(repeated_sentences, 'most_repeated_sentences.txt')
            return jsonify(default_predicted_words[:9])

        if not input_text:
            return jsonify({'error': 'No input text provided'}), 400

        append_list.append(input_text)
        print("Current append list:", append_list)

        combined_input = ' '.join(append_list)
        print("Combined input for prediction:", combined_input)

        predicted_words = generate_predicted_words(combined_input)
        print("Predicted words:", predicted_words)

        return jsonify(predicted_words[:9])

    except Exception as e:
        print(f"An error occurred: {str(e)}")  # Log the error message
        return jsonify({'error': str(e)}), 500


application = app

if __name__ == '__main__':
    application.run()
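
# Example client calls (a quick sketch, assuming the default Flask dev server
# on http://localhost:5000; adjust host/port for your deployment):
#
#   # Next-word suggestions for the current sentence (body shape matches data.get('item'))
#   curl -X POST http://localhost:5000/api/guu \
#        -H "Content-Type: application/json" -d '{"item": "hello"}'
#
#   # Sending "1" flushes the current sentence to dataset.txt and resets state
#   curl -X POST http://localhost:5000/api/guu \
#        -H "Content-Type: application/json" -d '{"item": "1"}'
#
#   # Preview image for a word; 'id' is matched against cached query_id values
#   curl "http://localhost:5000/api/images?query=apple&id=42"
#
#   # Paginated default word board and the top repeated sentences
#   curl "http://localhost:5000/api/display_words?count=0"
#   curl "http://localhost:5000/api/most_repeated_sentence"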