import json  # NOTE(review): currently unused in this file; presumably meant for output_file_path
import logging
import threading
from typing import List  # NOTE(review): currently unused in this file

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)

# Initialize the FastAPI app
app = FastAPI()

# Model and tokenizer paths and loading
model_path = "WhiteRabbitNeo/WhiteRabbitNeo-2.5-Qwen-2.5-Coder-7B"
# NOTE(review): defined but never read in this file; presumably intended for
# appending conversations as JSON Lines — confirm, then implement or remove.
output_file_path = "/home/user/conversations.jsonl"

# Sampling defaults (previously hard-coded inside generate_text).
DEFAULT_MAX_NEW_TOKENS = 2048
DEFAULT_TEMPERATURE = 0.75
DEFAULT_TOP_P = 1.0
DEFAULT_TOP_K = 50

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=False,
)
# The model already loads without remote code; keep the tokenizer consistent so
# no repository-supplied Python is executed at startup.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False)

# The endpoint below is a plain `def`, so FastAPI runs it in a worker thread.
# Previously the async endpoint blocked the event loop, which implicitly ran one
# generation at a time; this lock keeps that one-at-a-time GPU usage without
# stalling the server's event loop.
_generation_lock = threading.Lock()


def generate_text(
    instruction: str,
    *,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
    top_k: int = DEFAULT_TOP_K,
) -> str:
    """Sample a completion for *instruction* and return only the generated text.

    Args:
        instruction: Fully formatted prompt text (chat markup included).
        max_new_tokens: Maximum number of tokens to generate after the prompt.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens considered per step.

    Returns:
        The decoded continuation (prompt tokens stripped, special tokens skipped).
    """
    encoded = tokenizer(instruction, return_tensors="pt")
    # Place inputs on the model's device rather than assuming "cuda"; with
    # device_map="auto" the first layer may live on any device.
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)
    prompt_length = input_ids.shape[1]

    with _generation_lock, torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            # Explicit mask: pad_token_id equals eos_token_id below, so generate
            # cannot reliably infer the mask on its own.
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            top_k=top_k,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; keep only the new tokens.
    completion_ids = output_ids[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)


# Data model for FastAPI input
class UserInput(BaseModel):
    conversation: str
    user_input: str


@app.post("/generate/")
def generate_response(user_input: UserInput):
    """Append the user's message to the conversation, generate a reply, and return both.

    The response JSON has:
        response: the generated assistant text.
        updated_conversation: the prompt plus the reply, terminated so that the
            next request can append the next user message directly.

    Raises:
        HTTPException: 500 if generation fails.
    """
    # NOTE(review): assumes `conversation` already ends with an open user turn
    # ("<|im_start|>user\n"), as `updated_conversation` does — confirm what
    # clients send for the very first request.
    llm_prompt = (
        f"{user_input.conversation}{user_input.user_input}<|im_end|>\n<|im_start|>assistant\n"
        "Sure! Let me provide a complete and a thorough answer to your question, "
        "with functional and production-ready code.\n"
    )

    try:
        answer = generate_text(llm_prompt)
    except Exception as e:
        logger.exception("Text generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Close the assistant turn and open the next user turn for future requests.
    updated_conversation = f"{llm_prompt}{answer}<|im_end|>\n<|im_start|>user\n"
    return {
        "response": answer,
        "updated_conversation": updated_conversation,
    }


# Run the app
# To start the server, use the command: uvicorn filename:app --host 0.0.0.0 --port 8000
# (replace `filename` with this module's name)