admincybers2 committed
Commit a9fd595 · verified · 1 parent: 9b801f7

Create aitask.py

Files changed (1): aitask.py (+150, -0)
aitask.py ADDED
import os
import logging
import json
from uuid import uuid4

from confluent_kafka import KafkaException
from confluent_kafka.serialization import (
    MessageField,
    SerializationContext,
)
from transformers import TextStreamer
from unsloth import FastLanguageModel

# Surface CUDA errors at the failing kernel instead of asynchronously.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
hf_token = os.getenv("HF_TOKEN")


class MessageSend:
    def __init__(self, username, title, level, detail=None):
        self.username = username
        self.title = title
        self.level = level
        self.detail = detail


def cover_message(msg):
    """Return a dictionary representation of a MessageSend instance for serialization."""
    return dict(
        username=msg.username,
        title=msg.title,
        level=msg.level,
        detail=msg.detail,
    )


logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class TooManyRequestsError(Exception):
    def __init__(self, retry_after):
        self.retry_after = retry_after


model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="admincybers2/sentinal",
    max_seq_length=4096,
    dtype=None,  # auto-detect float16/bfloat16 for the current GPU
    load_in_4bit=True,
    token=hf_token,
)

# Enable native 2x faster inference
FastLanguageModel.for_inference(model)

vulnerable_prompt = (
    "Identify the line of code that is vulnerable and describe the type of "
    "software vulnerability, no yapping if no vulnerable code found pls return "
    "'no vulnerable'\n### Code Snippet:\n{}\n### Vulnerability Description:\n{}"
)


def extract_data(full_message):
    try:
        return json.loads(full_message)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to extract data: {e}")
        return None


def perform_ai_task(question):
    prompt = vulnerable_prompt.format(question, "")
    # Tokenized inputs must live on the same device as the 4-bit model.
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    text_streamer = TextStreamer(tokenizer)

    try:
        model_output = model.generate(
            **inputs,
            streamer=text_streamer,
            use_cache=True,
            max_new_tokens=640,
            temperature=0.5,
            top_k=50,
            top_p=0.9,
            min_p=0.01,
            typical_p=0.95,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3,
        )
        generated_text = tokenizer.decode(model_output[0], skip_special_tokens=True)
    except RuntimeError as e:
        error_message = str(e)
        if "probability tensor contains either `inf`, `nan` or element < 0" in error_message:
            logger.error("Encountered probability tensor error, skipping this task.")
        else:
            logger.error(f"Runtime error during model generation: {error_message}; skipping this task.")
        # generated_text was never assigned, so bail out in both cases.
        return None

    return {
        "detail": deduplicate_text(generated_text)
    }


def deduplicate_text(text):
    """Drop repeated sentences while preserving their original order."""
    sentences = text.split('. ')
    seen_sentences = set()
    deduplicated_sentences = []

    for sentence in sentences:
        if sentence not in seen_sentences:
            seen_sentences.add(sentence)
            deduplicated_sentences.append(sentence)

    # Avoid doubling the final period when the last sentence already ends with one.
    result = '. '.join(deduplicated_sentences)
    return result if result.endswith('.') else result + '.'


def delivery_report(err, msg):
    if err is not None:
        logger.error(f"Message delivery failed: {err}")
    else:
        logger.info(f"Message delivered to {msg.topic()} [{msg.partition()}]")


def handle_message(msg, producer, ensure_producer_connected, avro_serializer):
    logger.info(f'Message value {msg}')
    if not msg:
        return
    ensure_producer_connected(producer)
    try:
        ai_results = perform_ai_task(msg['message_send'])
        if ai_results is None:
            logger.error("AI task skipped due to an error in model generation.")
            return

        detail = ai_results.get("detail", "No details available")
        topic = "get_scan_message"

        messagedict = cover_message(
            MessageSend(
                username=msg['username'],
                title=msg['path'],
                level='',
                detail=detail,
            )
        )

        if messagedict:
            byte_value = avro_serializer(messagedict, SerializationContext(topic, MessageField.VALUE))
            producer.produce(
                topic,
                value=byte_value,
                headers={"correlation_id": str(uuid4())},
                callback=delivery_report,
            )
            producer.flush()
        else:
            logger.error("Message serialization failed; skipping production.")
    except KafkaException as e:
        logger.error(f"Kafka error producing message: {e}")
    except Exception as e:
        logger.error(f"Unhandled error in handle_message: {e}")