import streamlit as st
import chromadb
from chromadb.utils import embedding_functions
import groq
from typing import Dict
import os
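# NOTE: This is a Streamlit app; assuming the file is saved as app.py, launch it with
#   streamlit run app.py
# It expects an existing ChromaDB collection named "courses" under ./chroma_db
# (see the optional seeding sketch after the CourseAdvisor class below).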

class CourseAdvisor:
    def __init__(self, db_path: str = "./chroma_db"):
        """Initialize the course advisor with existing ChromaDB database."""
        # Initialize persistent client with path
        self.chroma_client = chromadb.PersistentClient(path=db_path)
        
        # Initialize embedding function
        self.embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name="jinaai/jina-embeddings-v2-base-en"
        )
        
        # Get existing collection
        self.collection = self.chroma_client.get_collection(
            name="courses",
            embedding_function=self.embedding_function
        )
        
    def query_courses(self, query_text: str, chat_history: str, api_key: str, n_results: int = 3) -> Dict:
        """Query the vector database and get course recommendations."""
        # Initialize Groq client with provided API key
        groq_client = groq.Groq(api_key=api_key)
        
        try:
            # Get relevant documents from vector DB
            results = self.collection.query(
                query_texts=[query_text],
                n_results=min(n_results, self.collection.count()),
                include=['documents', 'metadatas']
            )
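            # Chroma returns parallel lists per query; index [0] below selects the
            # hits for the single query_text passed above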

            # Prepare context from retrieved documents
            docs_context = "\n\n".join(results['documents'][0])
            
        except Exception as e:
            st.error(f"Error querying database: {str(e)}")
            return {
                'llm_response': "I encountered an error while searching the course database. Please try again.",
                'retrieved_courses': []
            }
        
        # Create prompt with chat history
        prompt = f"""Previous conversation:
{chat_history}

Current user query: {query_text}

Relevant course information:
{docs_context}

Please provide course recommendations based on the entire conversation context. Format your response as:
1. Understanding of the user's needs (based on conversation history)
2. Overall recommendation with reasoning
3. Specific benefits of each recommended course
4. Learning path suggestion (if applicable)
5. Any prerequisites or important notes"""

        try:
            # Get response from Groq
            completion = groq_client.chat.completions.create(
                messages=[
                    {"role": "system", "content": "You are a helpful course advisor who provides detailed, relevant course recommendations based on the user's needs and conversation history. Keep responses clear and well-structured."},
                    {"role": "user", "content": prompt}
                ],
                model="mixtral-8x7b-32768",
                temperature=0.7,
            )

            return {
                'llm_response': completion.choices[0].message.content,
                'retrieved_courses': results['metadatas'][0]
            }
            
        except Exception as e:
            st.error(f"Error with Groq API: {str(e)}")
            return {
                'llm_response': "I encountered an error while generating recommendations. Please check your API key and try again.",
                'retrieved_courses': []
            }
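
# The advisor above assumes a "courses" collection already exists on disk. The helper
# below is a minimal, optional sketch of how such a collection could be seeded; it is
# never called by the app. The metadata keys (title, categories, lessons, price, url)
# are inferred from display_course_card and are assumptions, not a required schema.
def seed_example_collection(db_path: str = "./chroma_db"):
    """Create a tiny demo 'courses' collection so the app has something to query."""
    client = chromadb.PersistentClient(path=db_path)
    embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="jinaai/jina-embeddings-v2-base-en"
    )
    collection = client.get_or_create_collection(
        name="courses",
        embedding_function=embedding_function
    )
    # One illustrative entry; real data would be added in bulk
    collection.add(
        ids=["course-001"],
        documents=["Introductory Python course covering syntax, data structures, and scripting."],
        metadatas=[{
            "title": "Python for Beginners",
            "categories": "Programming, Python",
            "lessons": 12,
            "price": "Free",
            "url": "https://example.com/python-for-beginners"
        }]
    )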

def initialize_session_state():
    """Initialize session state variables."""
    if 'messages' not in st.session_state:
        st.session_state.messages = []
    if 'course_advisor' not in st.session_state:
        st.session_state.course_advisor = CourseAdvisor()
    if 'api_key' not in st.session_state:
        st.session_state.api_key = ""

def get_chat_history() -> str:
    """Format chat history for LLM context."""
    history = []
    for message in st.session_state.messages[-5:]:  # Only use last 5 messages for context
        role = message["role"]
        content = message["content"]
        history.append(f"{role}: {content}")
    return "\n".join(history)

def display_course_card(course: Dict):
    """Display a single course recommendation in a card format."""
    with st.container():
        # Card styling injected once per card. Note: Streamlit renders the <div> wrappers
        # as standalone markdown, so the CSS mainly affects those markdown blocks rather
        # than truly enclosing the widgets between them.
        st.markdown("""
            <style>
            .course-card {
                background-color: #f8f9fa;
                padding: 1rem;
                border-radius: 0.5rem;
                margin-bottom: 1rem;
            }
            </style>
        """, unsafe_allow_html=True)

        st.markdown('<div class="course-card">', unsafe_allow_html=True)

        # Course title (fall back gracefully if the metadata has no title)
        st.markdown(f"### {course.get('title', 'Untitled course')}")

        col1, col2 = st.columns(2)

        with col1:
            # Handle categories - convert to list if string
            categories = course.get('categories', 'N/A')
            if isinstance(categories, str):
                # Split by comma if it's a comma-separated string
                categories = [cat.strip() for cat in categories.split(',')]
            elif not isinstance(categories, list):
                categories = [str(categories)]

            # Display categories as bullet points if multiple
            if len(categories) > 1:
                st.markdown("**Categories:**")
                for category in categories:
                    st.markdown(f"- {category}")
            else:
                st.markdown(f"**Category:** {categories[0]}")

            st.markdown(f"**Lessons:** {course.get('lessons', 'N/A')}")

        with col2:
            st.markdown(f"**Price:** {course.get('price', 'N/A')}")
            if 'url' in course:
                st.markdown(f"**[Visit Course]({course['url']})**")

        st.markdown('</div>', unsafe_allow_html=True)

        st.markdown("---")

def main():
    st.set_page_config(
        page_title="Course Recommender",
        page_icon="πŸ“š",
        layout="wide"
    )
    
    st.title("πŸ“š AI Course Recommender")
    
    # Initialize session state
    initialize_session_state()
    
    # Display collection info
    collection = st.session_state.course_advisor.collection
    st.sidebar.info(f"Connected to database with {collection.count()} courses")
    
    # Sidebar
    with st.sidebar:
        st.header("Settings")
        
        # API key input
        api_key = st.text_input("Enter Groq API Key", 
                               type="password",
                               value=st.session_state.api_key)
        if api_key != st.session_state.api_key:
            st.session_state.api_key = api_key
        
        # Clear chat button
        if st.button("Clear Chat History"):
            st.session_state.messages = []
    
    # Main chat interface
    st.header("Chat with AI Course Advisor")
    
    # Display chat history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    
    # Chat input
    if prompt := st.chat_input("What would you like to learn?"):
        # Check if API key is provided
        if not st.session_state.api_key:
            st.error("Please enter your Groq API key in the sidebar.")
            return
            
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})
        
        with st.chat_message("user"):
            st.markdown(prompt)
        
        # Get AI response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                # Get formatted chat history
                chat_history = get_chat_history()
                
                # Query courses with chat history
                response = st.session_state.course_advisor.query_courses(
                    prompt,
                    chat_history,
                    st.session_state.api_key
                )
                
                # Display AI recommendation
                st.markdown(response['llm_response'])
                
                # Display course cards if any courses were retrieved
                if response['retrieved_courses']:
                    st.markdown("### πŸ“‹ Recommended Courses")
                    for course in response['retrieved_courses']:
                        display_course_card(course)
        
        # Add assistant response to chat history (append a course summary only if any were retrieved)
        assistant_content = response['llm_response']
        if response['retrieved_courses']:
            assistant_content += "\n\n### Recommended Courses\n" + "\n".join(
                f"- {course.get('title', 'Untitled course')}" for course in response['retrieved_courses']
            )
        st.session_state.messages.append({
            "role": "assistant",
            "content": assistant_content
        })

if __name__ == "__main__":
    main()