saifeddinemk commited on
Commit
9b9a132
·
1 Parent(s): 6c70ef6

Fixed app v2

Browse files
Files changed (1) hide show
  1. app.py +62 -51
app.py CHANGED
@@ -1,68 +1,79 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from llama_cpp import Llama
4
- from functools import lru_cache
5
- import asyncio
6
  import uvicorn
7
 
8
  # Initialize FastAPI app
9
  app = FastAPI()
10
 
11
- # Lazy load the Llama model with float16 precision
12
- @lru_cache(maxsize=1)
13
- def load_model():
14
- try:
15
- return Llama.from_pretrained(
16
- repo_id="QuantFactory/SecurityLLM-GGUF",
17
- filename="SecurityLLM.Q8_0.gguf",
18
- torch_dtype="float16" # Specify FP16 precision
19
- )
20
- except Exception as e:
21
- raise RuntimeError(f"Failed to load model: {e}")
 
 
22
 
23
- # Define request model for log data
24
- class LogRequest(BaseModel):
 
 
 
 
 
 
 
25
  log_data: str
26
 
27
- # Define response model
28
- class AnalysisResponse(BaseModel):
29
  analysis: str
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  # Define the route for security log analysis
32
- @app.post("/analyze_security_logs", response_model=AnalysisResponse)
33
- async def analyze_security_logs(request: LogRequest):
34
- llm = load_model()
35
  try:
36
- # Security-focused prompt
37
- prompt = (
38
- "You are an advanced cybersecurity analysis assistant. Carefully analyze the following network log data for any indicators of malicious or suspicious activity. "
39
- "Specifically, look for patterns or unusual events that might suggest unauthorized access, data exfiltration, suspicious IP addresses, frequent access attempts, "
40
- "or other anomalies. Provide a detailed analysis that includes:\n\n"
41
- "1. A list of any suspicious IP addresses with explanations of why they are flagged as such.\n"
42
- "2. Any patterns or sequences in the logs that could indicate an ongoing attack or probing activity.\n"
43
- "3. Identified unauthorized access attempts, with details on the methods or vulnerabilities being exploited, if detectable.\n"
44
- "4. Recommendations on immediate actions or mitigations the system administrator should take to address any identified threats.\n"
45
- "5. An assessment of the overall security posture based on the log data, including any potential weaknesses or areas for improvement.\n\n"
46
- "Log Data:\n"
47
- f"{request.log_data}\n\n"
48
- "Please provide a comprehensive response addressing all points in detail."
49
- )
50
-
51
- # Generate response with controlled max tokens
52
- response = await asyncio.to_thread(
53
- llm.create_chat_completion,
54
- messages=[
55
- {
56
- "role": "user",
57
- "content": prompt
58
- }
59
- ],
60
- max_tokens=1024 # Adjust to limit the response length
61
- )
62
-
63
- # Extract and return the analysis text
64
- analysis_text = response["choices"][0]["message"]["content"]
65
- return AnalysisResponse(analysis=analysis_text)
66
  except Exception as e:
67
  raise HTTPException(status_code=500, detail=str(e))
68
 
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextStreamer
 
5
  import uvicorn
6
 
7
  # Initialize FastAPI app
8
  app = FastAPI()
9
 
10
+ # Configure and load the quantized model
11
+ model_id = 'model_result'
12
+
13
+ bnb_config = BitsAndBytesConfig(
14
+ load_in_4bit=True,
15
+ bnb_4bit_quant_type="nf4",
16
+ bnb_4bit_compute_dtype=torch.bfloat16,
17
+ bnb_4bit_use_double_quant=True,
18
+ )
19
+
20
+ # Load tokenizer and model with 4-bit quantization settings
21
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
22
+ tokenizer.pad_token = tokenizer.eos_token
23
 
24
+ model = AutoModelForCausalLM.from_pretrained(
25
+ model_id,
26
+ quantization_config=bnb_config,
27
+ device_map="auto",
28
+ )
29
+ model.eval()
30
+
31
+ # Define request and response models
32
+ class SecurityLogRequest(BaseModel):
33
  log_data: str
34
 
35
+ class SecurityAnalysisResponse(BaseModel):
 
36
  analysis: str
37
 
38
+ # Inference function
39
+ def generate_response(input_text: str) -> str:
40
+ streamer = TextStreamer(tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True)
41
+
42
+ messages = [
43
+ {"role": "system", "content": "You are an information security AI assistant specialized in analyzing security logs. Identify potential threats, suspicious IP addresses, unauthorized access attempts, and recommend actions based on the logs."},
44
+ {"role": "user", "content": f"Please analyze the following security logs and provide insights on any potential malicious activity:\n{input_text}"}
45
+ ]
46
+
47
+ input_ids = tokenizer.apply_chat_template(
48
+ messages,
49
+ tokenize=True,
50
+ add_generation_prompt=True,
51
+ return_tensors="pt",
52
+ ).to(model.device)
53
+
54
+ # Generate response with the model
55
+ outputs = model.generate(
56
+ input_ids,
57
+ streamer=streamer,
58
+ max_new_tokens=512, # Limit max tokens for faster response
59
+ num_beams=1,
60
+ do_sample=True,
61
+ temperature=0.1,
62
+ top_p=0.95,
63
+ top_k=10
64
+ )
65
+
66
+ # Extract and return generated text
67
+ response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
68
+ return response_text
69
+
70
  # Define the route for security log analysis
71
+ @app.post("/analyze_security_logs", response_model=SecurityAnalysisResponse)
72
+ async def analyze_security_logs(request: SecurityLogRequest):
 
73
  try:
74
+ # Run inference
75
+ analysis_text = generate_response(request.log_data)
76
+ return SecurityAnalysisResponse(analysis=analysis_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  except Exception as e:
78
  raise HTTPException(status_code=500, detail=str(e))
79