import os
import shutil
import sys

# Swap in pysqlite3 for the stdlib sqlite3 so chromadb sees a recent SQLite
# build; this must run before chromadb is imported.
__import__('pysqlite3')
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')

import chromadb
import gradio as gr
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Silence the oneDNN custom-ops warning emitted by TensorFlow.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'


torch.random.manual_seed(0)

# Note: the "-gguf" repo hosts GGUF weights for llama.cpp and cannot be loaded
# directly with AutoModelForCausalLM; use the standard transformers checkpoint.
model_name = "microsoft/Phi-3-mini-4k-instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    torch_dtype="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Make sure a pad token exists before padding in the tokenize step below.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Utility to clear a cached sentence-transformers model (not called below,
# but handy when a cached download is corrupted).
def clear_cache(model_name):
    cache_dir = os.path.expanduser(f'~/.cache/torch/sentence_transformers/{model_name.replace("/", "_")}')
    if os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)
        print(f"Cleared cache directory: {cache_dir}")
    else:
        print(f"No cache directory found for: {cache_dir}")

# Vector store backed by Chroma, embedding documents with sentence-transformers.
class VectorStore:
    def __init__(self, collection_name):
        try:
            self.embedding_model = SentenceTransformer('sentence-transformers/multi-qa-MiniLM-L6-cos-v1')
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
        self.chroma_client = chromadb.Client()
        # get_or_create avoids a crash if the collection already exists on re-run.
        self.collection = self.chroma_client.get_or_create_collection(name=collection_name)

    def populate_vectors(self, dataset=None, batch_size=20):
        # Load a 1,500-row slice of the recipe dataset when none is supplied.
        if dataset is None:
            dataset = load_dataset('Thefoodprocessor/recipe_new_with_features_full', split='train')
            dataset = dataset.select(range(1500))

        texts = []
        batch_start = 0
        for i, example in enumerate(dataset):
            title = example['title_cleaned']
            recipe = example['recipe_new']
            meal_type = example['meal_type']
            allergy = example['allergy_type']
            ingredients_alternative = example['ingredients_alternatives']
            texts.append(f"{title} {recipe} {meal_type} {allergy} {ingredients_alternative}")
            if len(texts) == batch_size:
                # Pass the index of the batch's first document so ids stay
                # aligned with dataset positions.
                self._process_batch(texts, batch_start)
                batch_start = i + 1
                texts = []
        if texts:
            self._process_batch(texts, batch_start)

    def _process_batch(self, texts, batch_start_idx):
        # Embed the whole batch in one call and add it with stable string ids.
        embeddings = self.embedding_model.encode(texts, batch_size=len(texts)).tolist()
        ids = [str(batch_start_idx + j) for j in range(len(texts))]
        self.collection.add(embeddings=embeddings, documents=texts, ids=ids)

    def search_context(self, query, n_results=1):
        # Chroma expects a list of query embeddings and returns one result
        # list per query.
        query_embedding = self.embedding_model.encode(query).tolist()
        return self.collection.query(query_embeddings=[query_embedding], n_results=n_results)
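
# Sketch of the result shape VectorStore.search_context returns (values here
# are illustrative, not real data): one inner list per query, so the top hit
# for a single query lives at results['documents'][0][0].
#   search_context("gluten free dinner", n_results=1)
#   -> {'ids': [['7']], 'documents': [['<matched recipe text>']], ...}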

vector_store = VectorStore("embedding_vector")
vector_store.populate_vectors()

def fine_tune_model():
    dataset = load_dataset('Thefoodprocessor/recipe_new_with_features_full', split='train')
    dataset = dataset.select(range(1500))

    def tokenize_function(examples):
        # Concatenate title and recipe into a single training text; cap the
        # sequence length, since padding to the model's full 4k context is
        # wasteful for short recipes.
        texts = [f"{title} {recipe}" for title, recipe in zip(examples['title_cleaned'], examples['recipe_new'])]
        return tokenizer(texts, padding="max_length", truncation=True, max_length=512)

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        batch_size=8,
        remove_columns=dataset.column_names,
    )

    # The causal-LM collator copies input_ids into labels so the Trainer can
    # compute a loss; without labels, training would fail.
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir="./results",
        evaluation_strategy="no",  # no eval_dataset is provided
        learning_rate=2e-5,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        num_train_epochs=3,
        weight_decay=0.01,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets,
        data_collator=data_collator,
    )

    trainer.train()

fine_tune_model()
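
# Optional sketch: persist the fine-tuned weights so later runs can skip
# training (standard transformers/tokenizers save APIs; the path is arbitrary).
# model.save_pretrained('./results/final')
# tokenizer.save_pretrained('./results/final')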

conversation_history = []

def chatbot_response(user_input):
    global conversation_history
    results = vector_store.search_context(user_input, n_results=1)
    # Chroma returns one list of documents per query; take the top hit.
    documents = results.get('documents') or [[]]
    context = documents[0][0] if documents[0] else ""
    conversation_history.append(f"User: {user_input}\nContext: {context[:150]}\nBot:")
    inputs = tokenizer("\n".join(conversation_history), return_tensors="pt")
    # max_new_tokens bounds only the continuation; max_length would also count
    # the growing prompt and eventually leave no room to generate.
    outputs = model.generate(**inputs, max_new_tokens=150, do_sample=True, temperature=0.7)
    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
    conversation_history.append(response)
    return response

def chat(user_input):
    return chatbot_response(user_input)
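
# Quick smoke test without the UI (left commented out so launching the Gradio
# app below stays the single entry point):
# print(chat("Suggest a nut-free breakfast recipe"))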

css = ".gradio-container {background: url(https://upload.wikimedia.org/wikipedia/commons/f/f5/Spring_Kitchen_Line-Up_%28Unsplash%29.jpg)}"
iface = gr.Interface(fn=chat, inputs="text", outputs="text", css=css)
iface.launch()