chenhaodev committed
Commit e021c17 · 1 Parent(s): 1bafebd

use llama-cpp-python instead of ollama

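For context: instead of calling out to a local Ollama server over HTTP, the app now loads a GGUF model in-process through llama-cpp-python. A minimal sketch of that call pattern (the model path and prompt below are illustrative, not taken from this commit):

    from llama_cpp import Llama

    # Load a quantized GGUF model from disk; no separate server process is needed
    llm = Llama(model_path="models/model.gguf", n_ctx=2048, verbose=False)

    # Calling the object runs a completion and returns an OpenAI-style dict
    out = llm("Extract all medical terms from: the patient has hypertension.", max_tokens=64)
    print(out["choices"][0]["text"])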
Files changed (3)
  1. Dockerfile +16 -64
  2. app.py +30 -40
  3. requirements.txt +1 -1
Dockerfile CHANGED
@@ -1,79 +1,31 @@
- # Use Ubuntu as base image
- FROM ubuntu:22.04
-
- # Prevent interactive prompts during package installation
- ENV DEBIAN_FRONTEND=noninteractive
+ FROM python:3.10-slim
 
  # Install system dependencies
- RUN apt-get update && apt-get install -y \
-     python3 \
-     python3-pip \
-     curl \
-     wget \
+ RUN apt-get update && \
+     apt-get install -y \
+     build-essential \
+     python3-dev \
      git \
-     net-tools \
+     wget \
      && rm -rf /var/lib/apt/lists/*
 
- # Install Ollama
- RUN curl -fsSL https://ollama.com/install.sh | sh
-
- # Set working directory
  WORKDIR /app
 
- # Copy requirements and install Python dependencies
+ # Copy requirements first for better caching
  COPY requirements.txt .
- RUN pip3 install --no-cache-dir -r requirements.txt
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Create model directory
+ RUN mkdir -p /app/models
+
+ # Download the GGUF model (replace with your preferred Qwen GGUF model)
+ RUN wget -P /app/models https://huggingface.co/unsloth/DeepSeek-R1-Distill-Qwen-1.5B-GGUF/resolve/main/DeepSeek-R1-Distill-Qwen-1.5B-Q4_K_M.gguf
 
  # Copy application code
  COPY . .
 
- # Create startup script with health checks and retries
- RUN echo '#!/bin/bash\n\
- \n\
- # Function to check if Ollama is responsive\n\
- check_ollama() {\n\
-     curl -s http://localhost:11434/api/version &>/dev/null\n\
- }\n\
- \n\
- # Start Ollama server\n\
- ollama serve &\n\
- \n\
- # Wait for Ollama to be responsive (up to 60 seconds)\n\
- count=0\n\
- while ! check_ollama && [ $count -lt 60 ]; do\n\
-     echo "Waiting for Ollama server to start..."\n\
-     sleep 1\n\
-     count=$((count + 1))\n\
- done\n\
- \n\
- if ! check_ollama; then\n\
-     echo "Failed to start Ollama server"\n\
-     exit 1\n\
- fi\n\
- \n\
- # Pull the model with retry logic\n\
- max_retries=3\n\
- retry_count=0\n\
- while [ $retry_count -lt $max_retries ]; do\n\
-     if ollama pull deepseek-r1:1.5b; then\n\
-         break\n\
-     fi\n\
-     echo "Failed to pull model, retrying..."\n\
-     retry_count=$((retry_count + 1))\n\
-     sleep 5\n\
- done\n\
- \n\
- if [ $retry_count -eq $max_retries ]; then\n\
-     echo "Failed to pull model after $max_retries attempts"\n\
-     exit 1\n\
- fi\n\
- \n\
- # Start the Gradio app\n\
- exec python3 -u app.py\n\
- ' > start.sh && chmod +x start.sh
-
- # Expose port for Gradio web interface
+ # Expose port for Gradio
  EXPOSE 7860
 
  # Run the application
- ENTRYPOINT ["./start.sh"]
+ CMD ["python3", "app.py"]
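Because the model is now baked into the image at build time, it is worth checking that the downloaded file actually loads before starting the app. A hypothetical smoke test, not part of this commit (the path matches the Dockerfile above):

    from llama_cpp import Llama

    # Fails fast if the wget step fetched an error page instead of the raw GGUF file
    llm = Llama(model_path="/app/models/DeepSeek-R1-Distill-Qwen-1.5B-Q4_K_M.gguf",
                n_ctx=512, verbose=False)
    print(llm("Say OK.", max_tokens=8)["choices"][0]["text"])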
app.py CHANGED
@@ -3,46 +3,33 @@ import jiwer
  import pandas as pd
  import logging
  from typing import List, Optional, Tuple, Dict
- from ollama import Client
- import re
+ from llama_cpp import Llama
  import os
- import time
- import requests
 
- # Set up logging configuration
+ # Set up logging
  logging.basicConfig(
      level=logging.INFO,
      format='%(asctime)s - %(levelname)s - %(message)s',
      force=True,
-     handlers=[
-         logging.StreamHandler(),
-     ]
+     handlers=[logging.StreamHandler()]
  )
  logger = logging.getLogger(__name__)
 
- # Initialize Ollama client with retry logic
- def init_ollama_client(max_retries=5):
-     client = None
-     for i in range(max_retries):
-         try:
-             client = Client(host='http://localhost:11434')
-             # Test the connection
-             response = requests.get('http://localhost:11434/api/version')
-             if response.status_code == 200:
-                 logger.info("Successfully connected to Ollama")
-                 return client
-         except Exception as e:
-             logger.warning(f"Attempt {i+1}/{max_retries} to connect to Ollama failed: {str(e)}")
-             if i < max_retries - 1:
-                 time.sleep(2)
-     raise Exception("Failed to initialize Ollama client")
+ # Initialize LLM
+ MODEL_PATH = "/app/models/DeepSeek-R1-Distill-Qwen-1.5B-Q4_K_M.gguf"
 
- # Global client initialization
  try:
-     client = init_ollama_client()
+     llm = Llama(
+         model_path=MODEL_PATH,
+         n_ctx=2048,      # Context window
+         n_threads=4,     # CPU threads
+         n_batch=512,     # Batch size
+         verbose=False    # Disable verbose output
+     )
+     logger.info("LLM initialized successfully")
  except Exception as e:
-     logger.error(f"Failed to initialize Ollama: {str(e)}")
-     client = None
+     logger.error(f"Failed to initialize LLM: {str(e)}")
+     llm = None
 
  def calculate_wer_metrics(
      hypothesis: str,
@@ -124,29 +111,30 @@ def calculate_wer_metrics(
      return measures
 
  def extract_medical_terms(text: str) -> List[str]:
-     """Extract medical terms from text using Qwen model via Ollama."""
-     if client is None:
-         logger.error("Ollama client not initialized")
+     """Extract medical terms from text using Qwen model."""
+     if llm is None:
+         logger.error("LLM not initialized")
          return []
      prompt = f"""Extract all medical terms from the following text.
      Return only the medical terms as a comma-separated list.
      Text: {text}"""
 
      try:
-         response = client.generate(
-             model='deepseek-r1:1.5b',
-             prompt=prompt,
-             stream=False
+         response = llm(
+             prompt,
+             max_tokens=256,
+             temperature=0.1,
+             stop=["Text:", "\n\n"],
+             echo=False
          )
 
-         response_text = response['response']
+         response_text = response['choices'][0]['text'].strip()
 
-         # Remove the thinking process
+         # Remove thinking process if present
          if '<think>' in response_text and '</think>' in response_text:
              medical_terms_text = response_text.split('</think>')[-1].strip()
          else:
              medical_terms_text = response_text
-
          medical_terms = [term.strip() for term in medical_terms_text.split(',')]
          return [term for term in medical_terms if term and not term.startswith('<') and not term.endswith('>')]
 
@@ -198,9 +186,12 @@ def process_inputs(
 
      try:
          # Extract medical terms
+         logger.info("Extracting medical terms from reference text...")
          reference_terms = extract_medical_terms(reference)
+         logger.info(f"Reference terms extracted: {reference_terms}")
+         logger.info("Extracting medical terms from hypothesis text...")
          hypothesis_terms = extract_medical_terms(hypothesis)
-
+         logger.info(f"Hypothesis terms extracted: {hypothesis_terms}")
          # Calculate medical recall
          med_recall = calculate_medical_recall(hypothesis_terms, reference_terms)
 
@@ -332,7 +323,6 @@
          inputs=[reference, hypothesis, normalize, words_to_filter],
          outputs=[metrics_output, error_output, explanation_output, error_msg_output]
      )
-
      return interface
 
  if __name__ == "__main__":
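A note on the '</think>' split kept in extract_medical_terms: DeepSeek-R1 distills emit their reasoning inside <think>...</think> tags before the final answer, so only the text after the closing tag is parsed. An illustrative run of that step (the model output shown is invented):

    # Invented sample completion from the model
    raw = "<think>scanning for clinical vocabulary</think> hypertension, metformin, CHF"

    # Mirrors the parsing logic in extract_medical_terms above
    terms_text = raw.split('</think>')[-1].strip() if '</think>' in raw else raw
    print([t.strip() for t in terms_text.split(',')])  # ['hypertension', 'metformin', 'CHF']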
requirements.txt CHANGED
@@ -1,4 +1,4 @@
  gradio==5.16.0
  jiwer==3.1.0
  pandas==2.2.0
- ollama==0.4.5
+ llama-cpp-python