Coverage for src/ollamapy/ollama_client.py: 55%

139 statements  

coverage.py v7.10.6, created at 2025-09-01 12:29 -0400

1"""Enhanced Ollama API client with model context size support""" 

2 

3import json 

4import logging 

5import re 

6import requests 

7from typing import Dict, List, Optional, Generator, Any 

8 

9logger = logging.getLogger(__name__) 

10 

11 

12class OllamaClient: 

13 """Enhanced Ollama API client with model context size support""" 

14 

15 def __init__(self, base_url: str = "http://localhost:11434"): 

16 """Initialize the Ollama client. 

17 

18 Args: 

19 base_url: The base URL for the Ollama API server 

20 """ 

21 self.base_url = base_url.rstrip("/") 

22 self.session = requests.Session() 

23 self._model_cache: Dict[str, int] = {} 

24 

    def is_available(self) -> bool:
        """Check if the Ollama server is running and accessible."""
        try:
            response = self.session.get(f"{self.base_url}/api/tags", timeout=5)
            return response.status_code == 200
        except requests.exceptions.RequestException:
            return False


    def list_models(self) -> List[str]:
        """Get the list of available models."""
        try:
            response = self.session.get(f"{self.base_url}/api/tags")
            response.raise_for_status()
            data = response.json()
            return [model["name"] for model in data.get("models", [])]
        except requests.exceptions.RequestException:
            return []


    def get_model_context_size(self, model: str) -> int:
        """Get the context window size for a model."""
        if model in self._model_cache:
            return self._model_cache[model]

        try:
            response = self.session.post(
                f"{self.base_url}/api/show", json={"name": model}
            )
            response.raise_for_status()
            data = response.json()

            # Try to extract context size from model info
            context_size = self._get_default_context_size(model)
            if "modelfile" in data:
                # Look for a num_ctx parameter in the modelfile
                match = re.search(r'num_ctx["\s]+(\d+)', data["modelfile"])
                if match:
                    context_size = int(match.group(1))

            self._model_cache[model] = context_size
            return context_size
        except (requests.exceptions.RequestException, ValueError):
            # Catch specific errors rather than a bare except so that
            # unrelated bugs are not silently swallowed.
            return 4096  # Default context size

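    # Illustrative example of the lookup above: given a modelfile containing a
    # line like 'PARAMETER num_ctx 8192' (a hypothetical value, not taken from
    # a real model), the regex matches 'num_ctx 8192' and the method caches and
    # returns a context size of 8192 for that model.
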

    def _get_default_context_size(self, model: str) -> int:
        """Get default context size based on model name."""
        model_lower = model.lower()

        # Known context sizes for popular models
        if "gemma3:4b" in model_lower or "gemma2-2b" in model_lower:
            return 128000
        elif "llama3.2:3b" in model_lower or "llama3.2:1b" in model_lower:
            return 128000
        elif "llama3.1" in model_lower:
            return 128000
        elif "llama3:8b" in model_lower or "llama3:7b" in model_lower:
            return 8192
        elif "mistral" in model_lower:
            return 32768
        elif "codellama" in model_lower:
            return 16384
        else:
            return 4096  # Conservative default


    def estimate_tokens(self, text: str) -> int:
        """Estimate token count for text (rough approximation).

        This uses a simple heuristic: ~4 characters per token on average.
        For more accurate counting, we could use tiktoken or similar, but
        this approximation is sufficient for context monitoring.
        """
        if not text:
            return 0

        # Simple estimation: 1 token ≈ 4 characters (varies by model/language);
        # dividing by 3.5 is slightly more conservative, i.e. estimates more tokens.
        char_count = len(text)
        estimated_tokens = int(char_count / 3.5)

        # Add some tokens for special tokens and formatting
        return estimated_tokens + 10

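    # Worked example of the heuristic (illustrative only): a 350-character
    # prompt estimates to int(350 / 3.5) = 100 tokens, plus the 10-token
    # buffer, for 110 tokens. A real tokenizer may count more or fewer.
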

    def count_prompt_tokens(
        self, prompt: str, system: Optional[str] = None, context: Optional[str] = None
    ) -> int:
        """Count approximate tokens for a complete prompt.

        Args:
            prompt: The main prompt text
            system: Optional system message
            context: Optional additional context

        Returns:
            Estimated total token count
        """
        total_tokens = 0

        if system:
            total_tokens += self.estimate_tokens(system)

        if context:
            total_tokens += self.estimate_tokens(context)

        total_tokens += self.estimate_tokens(prompt)

        # Add some buffer for formatting and special tokens
        return total_tokens + 20

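    # Worked example (illustrative): a 200-character system message estimates
    # to int(200 / 3.5) + 10 = 67 tokens and a 700-character prompt to
    # 200 + 10 = 210, so count_prompt_tokens returns 67 + 210 + 20 = 297.
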

    def get_context_usage(
        self,
        model: str,
        prompt: str,
        system: Optional[str] = None,
        context: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Get context window usage information.

        Args:
            model: The model name
            prompt: The prompt text
            system: Optional system message
            context: Optional additional context

        Returns:
            Dict with usage information
        """
        max_context = self.get_model_context_size(model)
        used_tokens = self.count_prompt_tokens(prompt, system, context)

        # Reserve some tokens for the response (typically 20-30% of context)
        response_reserve = int(max_context * 0.25)  # Reserve 25% for response
        available_for_prompt = max_context - response_reserve

        usage_percent = (used_tokens / available_for_prompt) * 100

        return {
            "model": model,
            "max_context": max_context,
            "used_tokens": used_tokens,
            "available_tokens": available_for_prompt,
            "reserved_for_response": response_reserve,
            "usage_percent": usage_percent,
            "is_over_limit": used_tokens > available_for_prompt,
        }

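    # Worked example (illustrative): llama3:8b defaults to an 8192-token
    # window; 25% (2048 tokens) is reserved for the response, leaving 6144
    # for the prompt. With 297 estimated tokens, usage is 297 / 6144 ≈ 4.8%.
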

    def print_context_usage(
        self,
        model: str,
        prompt: str,
        system: Optional[str] = None,
        context: Optional[str] = None,
    ) -> None:
        """Print context usage information."""
        usage = self.get_context_usage(model, prompt, system, context)

        # Color coding for usage levels
        if usage["usage_percent"] < 50:
            color = "\033[92m"  # Green
            status = "🟢"
        elif usage["usage_percent"] < 80:
            color = "\033[93m"  # Yellow
            status = "🟡"
        elif usage["usage_percent"] < 100:
            color = "\033[91m"  # Red
            status = "🟠"
        else:
            color = "\033[91m"  # Red
            status = "🔴"

        reset = "\033[0m"  # Reset color

        print(
            f"{status} Context: {color}{usage['usage_percent']:.1f}%{reset} "
            f"({usage['used_tokens']}/{usage['available_tokens']} tokens, {usage['model']})"
        )

        if usage["is_over_limit"]:
            print("⚠️ Warning: Prompt exceeds recommended context limit!")

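    # Example output (illustrative, ANSI color codes omitted), continuing the
    # 297-token example above:
    #   🟢 Context: 4.8% (297/6144 tokens, llama3:8b)
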

    def generate(
        self,
        model: str,
        prompt: str,
        system: Optional[str] = None,
        show_context: bool = True,
    ) -> str:
        """Generate a response from the model with context monitoring."""
        try:
            # Show context usage if requested
            if show_context:
                self.print_context_usage(model, prompt, system)

            payload = {"model": model, "prompt": prompt, "stream": False}
            if system:
                payload["system"] = system

            response = self.session.post(
                f"{self.base_url}/api/generate", json=payload, timeout=60
            )
            response.raise_for_status()
            return response.json()["response"]
        except requests.exceptions.RequestException as e:
            logger.error(f"Generation failed: {e}")
            return ""


    def pull_model(self, model: str) -> bool:
        """Pull a model if it's not available locally."""
        try:
            response = self.session.post(
                f"{self.base_url}/api/pull", json={"name": model}, stream=True
            )
            response.raise_for_status()

            for line in response.iter_lines():
                if line:
                    data = json.loads(line)
                    if "status" in data:
                        print(f"\r{data['status']}", end="", flush=True)
                    if data.get("status") == "success":
                        print()  # New line after completion
                        return True
            return True
        except requests.exceptions.RequestException as e:
            print(f"Error pulling model: {e}")
            return False


    def chat_stream(
        self, model: str, messages: List[Dict[str, str]], system: Optional[str] = None
    ) -> Generator[str, None, None]:
        """Stream chat responses from Ollama.

        Args:
            model: The model to use for chat
            messages: List of message dicts with 'role' and 'content'
            system: Optional system message

        Yields:
            Response chunks as strings
        """
        payload = {"model": model, "messages": messages, "stream": True}

        if system:
            payload["system"] = system

        try:
            response = self.session.post(
                f"{self.base_url}/api/chat", json=payload, stream=True
            )
            response.raise_for_status()

            for line in response.iter_lines():
                if line:
                    data = json.loads(line)
                    if "message" in data and "content" in data["message"]:
                        yield data["message"]["content"]
                    if data.get("done", False):
                        break

        except requests.exceptions.RequestException as e:
            yield f"Error: {e}"