Coverage for src/ollamapy/ollama_client.py: 55%
139 statements
coverage.py v7.10.6, created at 2025-09-01 12:29 -0400
1"""Enhanced Ollama API client with model context size support"""
3import json
4import logging
5import re
6import requests
7from typing import Dict, List, Optional, Generator, Any
9logger = logging.getLogger(__name__)


class OllamaClient:
    """Enhanced Ollama API client with model context size support"""

    def __init__(self, base_url: str = "http://localhost:11434"):
        """Initialize the Ollama client.

        Args:
            base_url: The base URL for the Ollama API server.
        """
        self.base_url = base_url.rstrip("/")
        self.session = requests.Session()
        self._model_cache: Dict[str, int] = {}

    def is_available(self) -> bool:
        """Check if Ollama server is running and accessible."""
        try:
            response = self.session.get(f"{self.base_url}/api/tags", timeout=5)
            return response.status_code == 200
        except requests.exceptions.RequestException:
            return False

    def list_models(self) -> List[str]:
        """Get list of available models."""
        try:
            response = self.session.get(f"{self.base_url}/api/tags")
            response.raise_for_status()
            data = response.json()
            return [model["name"] for model in data.get("models", [])]
        except requests.exceptions.RequestException:
            return []

    def get_model_context_size(self, model: str) -> int:
        """Get the context window size for a model."""
        if model in self._model_cache:
            return self._model_cache[model]

        try:
            response = self.session.post(
                f"{self.base_url}/api/show", json={"name": model}
            )
            response.raise_for_status()
            data = response.json()

            # Try to extract context size from model info
            context_size = self._get_default_context_size(model)
            if "modelfile" in data:
                # Look for context size in modelfile
                match = re.search(r'num_ctx["\s]+(\d+)', data["modelfile"])
                if match:
                    context_size = int(match.group(1))

            self._model_cache[model] = context_size
            return context_size
        except (requests.exceptions.RequestException, ValueError):
            logger.warning(f"Could not determine context size for {model}; using default")
            return 4096  # Default context size

    def _get_default_context_size(self, model: str) -> int:
        """Get default context size based on model name."""
        model_lower = model.lower()

        # Known context sizes for popular models
        if "gemma3:4b" in model_lower or "gemma2-2b" in model_lower:
            return 128000
        elif "llama3.2:3b" in model_lower or "llama3.2:1b" in model_lower:
            return 128000
        elif "llama3.1" in model_lower:
            return 128000
        elif "llama3:8b" in model_lower or "llama3:7b" in model_lower:
            return 8192
        elif "mistral" in model_lower:
            return 32768
        elif "codellama" in model_lower:
            return 16384
        else:
            return 4096  # Conservative default

    def estimate_tokens(self, text: str) -> int:
        """Estimate token count for text (rough approximation).

        This uses a simple heuristic of roughly 3.5-4 characters per token.
        For more accurate counting, we could use tiktoken or similar, but
        this approximation is sufficient for context monitoring.
        """
        if not text:
            return 0

        # Simple estimation: 1 token ≈ 3.5 characters (slightly conservative;
        # the real ratio varies by model and language)
        char_count = len(text)
        estimated_tokens = char_count / 3.5

        # Add some tokens for special tokens and formatting
        return int(estimated_tokens) + 10

    def count_prompt_tokens(
        self, prompt: str, system: Optional[str] = None, context: Optional[str] = None
    ) -> int:
        """Count approximate tokens for a complete prompt.

        Args:
            prompt: The main prompt text
            system: Optional system message
            context: Optional additional context

        Returns:
            Estimated total token count
        """
        total_tokens = 0

        if system:
            total_tokens += self.estimate_tokens(system)

        if context:
            total_tokens += self.estimate_tokens(context)

        total_tokens += self.estimate_tokens(prompt)

        # Add some buffer for formatting and special tokens
        return total_tokens + 20

    def get_context_usage(
        self,
        model: str,
        prompt: str,
        system: Optional[str] = None,
        context: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Get context window usage information.

        Args:
            model: The model name
            prompt: The prompt text
            system: Optional system message
            context: Optional additional context

        Returns:
            Dict with usage information
        """
        max_context = self.get_model_context_size(model)
        used_tokens = self.count_prompt_tokens(prompt, system, context)

        # Reserve some tokens for the response (typically 20-30% of context)
        response_reserve = int(max_context * 0.25)  # Reserve 25% for response
        available_for_prompt = max_context - response_reserve

        usage_percent = (used_tokens / available_for_prompt) * 100

        return {
            "model": model,
            "max_context": max_context,
            "used_tokens": used_tokens,
            "available_tokens": available_for_prompt,
            "reserved_for_response": response_reserve,
            "usage_percent": usage_percent,
            "is_over_limit": used_tokens > available_for_prompt,
        }

    def print_context_usage(
        self,
        model: str,
        prompt: str,
        system: Optional[str] = None,
        context: Optional[str] = None,
    ) -> None:
        """Print context usage information."""
        usage = self.get_context_usage(model, prompt, system, context)

        # Color coding for usage levels
        if usage["usage_percent"] < 50:
            color = "\033[92m"  # Green
            status = "🟢"
        elif usage["usage_percent"] < 80:
            color = "\033[93m"  # Yellow
            status = "🟡"
        elif usage["usage_percent"] < 100:
            color = "\033[91m"  # Red
            status = "🟠"
        else:
            color = "\033[91m"  # Red
            status = "🔴"

        reset = "\033[0m"  # Reset color

        print(
            f"{status} Context: {color}{usage['usage_percent']:.1f}%{reset} "
            f"({usage['used_tokens']}/{usage['available_tokens']} tokens, {usage['model']})"
        )

        if usage["is_over_limit"]:
            print("⚠️ Warning: Prompt exceeds recommended context limit!")

    def generate(
        self,
        model: str,
        prompt: str,
        system: Optional[str] = None,
        show_context: bool = True,
    ) -> str:
        """Generate a response from the model with context monitoring."""
        try:
            # Show context usage if requested
            if show_context:
                self.print_context_usage(model, prompt, system)

            payload = {"model": model, "prompt": prompt, "stream": False}
            if system:
                payload["system"] = system

            response = self.session.post(
                f"{self.base_url}/api/generate", json=payload, timeout=60
            )
            response.raise_for_status()
            return response.json()["response"]
        except requests.exceptions.RequestException as e:
            logger.error(f"Generation failed: {e}")
            return ""

    def pull_model(self, model: str) -> bool:
        """Pull a model if it's not available locally."""
        try:
            response = self.session.post(
                f"{self.base_url}/api/pull", json={"name": model}, stream=True
            )
            response.raise_for_status()

            for line in response.iter_lines():
                if line:
                    data = json.loads(line)
                    if "status" in data:
                        print(f"\r{data['status']}", end="", flush=True)
                    if data.get("status") == "success":
                        print()  # New line after completion
                        return True
            return True
        except requests.exceptions.RequestException as e:
            print(f"Error pulling model: {e}")
            return False

    def chat_stream(
        self, model: str, messages: List[Dict[str, str]], system: Optional[str] = None
    ) -> Generator[str, None, None]:
        """Stream chat responses from Ollama.

        Args:
            model: The model to use for chat
            messages: List of message dicts with 'role' and 'content'
            system: Optional system message

        Yields:
            Response chunks as strings
        """
        payload = {"model": model, "messages": messages, "stream": True}

        if system:
            payload["system"] = system

        try:
            response = self.session.post(
                f"{self.base_url}/api/chat", json=payload, stream=True
            )
            response.raise_for_status()

            for line in response.iter_lines():
                if line:
                    data = json.loads(line)
                    if "message" in data and "content" in data["message"]:
                        yield data["message"]["content"]
                    if data.get("done", False):
                        break

        except requests.exceptions.RequestException as e:
            yield f"Error: {e}"
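
A minimal usage sketch follows, assuming the package is importable as "ollamapy", a local Ollama server is reachable at the default http://localhost:11434, and that the model name "llama3.2:3b" is purely illustrative (it is not referenced by the module above).

# Usage sketch (not part of the module above): assumes a reachable local
# Ollama server; the model name is illustrative.
from ollamapy.ollama_client import OllamaClient

client = OllamaClient()
if client.is_available():
    model = "llama3.2:3b"  # illustrative; pick any entry from client.list_models()
    if model not in client.list_models():
        client.pull_model(model)

    # Inspect context usage before sending a prompt
    usage = client.get_context_usage(model, "Explain what a context window is.")
    print(f"{usage['usage_percent']:.1f}% of available context used")

    # One-shot generation (prints context usage by default via show_context=True)
    reply = client.generate(model, "Explain what a context window is.", system="Be concise.")
    print(reply)

    # Streaming chat
    for chunk in client.chat_stream(model, [{"role": "user", "content": "Hello!"}]):
        print(chunk, end="", flush=True)
    print()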