fix: Improve Ollama client detection and error handling (#373) #672

Status: Open. Wants to merge 1 commit into base: main.

6 changes: 4 additions & 2 deletions requirements.txt
@@ -2,7 +2,7 @@ flask
 flask-cors
 toml
 urllib3
-requests
+requests>=2.31.0
 colorama
 fastlogging
 Jinja2
@@ -12,7 +12,7 @@ pdfminer.six
 playwright
 pytest-playwright
 tiktoken
-ollama
+ollama>=0.1.6
 openai
 anthropic
 google-generativeai
@@ -31,3 +31,5 @@ orjson
 gevent
 gevent-websocket
 curl_cffi
+pytest>=7.4.0
+pytest-mock>=3.12.0
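
Two of these pins are worth a note for reviewers: requests>=2.31.0 picks up the fix for CVE-2023-32681 (the Proxy-Authorization header leak), and ollama>=0.1.6 is presumably the oldest Python client release compatible with the Client(host) call used in the new code; pytest and pytest-mock back the test file added below. Everything installs with the usual pip install -r requirements.txt.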
105 changes: 90 additions & 15 deletions src/llm/ollama_client.py
@@ -1,25 +1,100 @@
+import os
+import time
+import requests
+from typing import Optional, List, Dict, Any
+from urllib.parse import urlparse
+
 import ollama
 from src.logger import Logger
 from src.config import Config
 
 log = Logger()
 
 
 class Ollama:
     def __init__(self):
-        try:
-            self.client = ollama.Client(Config().get_ollama_api_endpoint())
-            self.models = self.client.list()["models"]
-            log.info("Ollama available")
-        except:
-            self.client = None
-            log.warning("Ollama not available")
-            log.warning("run ollama server to use ollama models otherwise use API models")
+        """Initialize Ollama client with retry logic and proper error handling."""
+        self.host = os.getenv("OLLAMA_HOST", Config().get_ollama_api_endpoint())
+        self.client = None
+        self.models = []
+        self._initialize_client()
+
+    def _initialize_client(self, max_retries: int = 3, initial_delay: float = 1.0) -> None:
+        """Initialize Ollama client with retry logic.
+
+        Args:
+            max_retries: Maximum number of connection attempts
+            initial_delay: Initial delay between retries in seconds
+        """
+        delay = initial_delay
+        for attempt in range(max_retries):
+            try:
+                # Validate URL format
+                parsed_url = urlparse(self.host)
+                if not parsed_url.scheme or not parsed_url.netloc:
+                    raise ValueError(f"Invalid Ollama server URL: {self.host}")
+
+                # Test server connection
+                response = requests.get(f"{self.host}/api/version")
+                if response.status_code != 200:
+                    raise ConnectionError(f"Ollama server returned status {response.status_code}")
+
+                # Initialize client and fetch models
+                self.client = ollama.Client(self.host)
+                self.models = self.client.list()["models"]
+                log.info(f"Ollama available at {self.host}")
+                log.info(f"Found {len(self.models)} models: {[m['name'] for m in self.models]}")
+                return
+
+            except requests.exceptions.ConnectionError as e:
+                log.warning(f"Connection failed to Ollama server at {self.host}")
+                log.warning(f"Error: {str(e)}")
+
+            except ValueError as e:
+                log.error(f"Configuration error: {str(e)}")
+                return
+
+            except Exception as e:
+                log.warning(f"Failed to initialize Ollama client: {str(e)}")
+
+            if attempt < max_retries - 1:
+                log.info(f"Retrying in {delay:.1f} seconds...")
+                time.sleep(delay)
+                delay *= 2  # Exponential backoff
+            else:
+                log.warning("Max retries reached. Please ensure Ollama server is running")
+                log.warning("Run 'ollama serve' to start the server")
+                log.warning("Or set OLLAMA_HOST environment variable to correct server URL")
+
+        self.client = None
+        self.models = []
 
     def inference(self, model_id: str, prompt: str) -> str:
-        response = self.client.generate(
-            model=model_id,
-            prompt=prompt.strip(),
-            options={"temperature": 0}
-        )
-        return response['response']
+        """Run inference using specified model.
+
+        Args:
+            model_id: Name of the Ollama model to use
+            prompt: Input prompt for the model
+
+        Returns:
+            Model response text
+
+        Raises:
+            RuntimeError: If client is not initialized or model is not found
+        """
+        if not self.client:
+            raise RuntimeError("Ollama client not initialized. Please check server connection.")
+
+        if not any(m['name'] == model_id for m in self.models):
+            raise RuntimeError(f"Model {model_id} not found in available models: {[m['name'] for m in self.models]}")
+
+        try:
+            response = self.client.generate(
+                model=model_id,
+                prompt=prompt.strip(),
+                options={"temperature": 0}
+            )
+            return response['response']
+
+        except Exception as e:
+            log.error(f"Inference failed for model {model_id}: {str(e)}")
+            raise RuntimeError(f"Failed to get response from Ollama: {str(e)}")
129 changes: 129 additions & 0 deletions tests/test_ollama.py
@@ -0,0 +1,129 @@
import pytest
import os
import requests
from unittest.mock import patch, MagicMock
from src.llm.ollama_client import Ollama
from src.config import Config

def test_ollama_client_initialization():
    """Test Ollama client initialization with default config"""
    with patch('requests.get') as mock_get, \
         patch('ollama.Client') as mock_client:
        mock_get.return_value = MagicMock(status_code=200)
        mock_client.return_value.list.return_value = {"models": []}

        client = Ollama()
        assert client.host == Config().get_ollama_api_endpoint()
        assert client.client is not None
        assert isinstance(client.models, list)

def test_ollama_client_initialization_with_env():
    """Test Ollama client initialization with environment variable"""
    with patch('requests.get') as mock_get, \
         patch('ollama.Client') as mock_client, \
         patch.dict(os.environ, {'OLLAMA_HOST': 'http://ollama-service:11434'}):
        mock_get.return_value = MagicMock(status_code=200)
        mock_client.return_value.list.return_value = {"models": []}

        client = Ollama()
        assert client.host == "http://ollama-service:11434"
        assert client.client is not None

def test_ollama_client_connection_retry():
    """Test Ollama client connection retry logic"""
    with patch('requests.get') as mock_get, \
         patch('ollama.Client') as mock_client, \
         patch('time.sleep') as mock_sleep:
        # Simulate first two failures, then success
        mock_get.side_effect = [
            requests.exceptions.ConnectionError(),
            requests.exceptions.ConnectionError(),
            MagicMock(status_code=200)
        ]
        mock_client.return_value.list.return_value = {"models": []}

        client = Ollama()
        assert client.client is not None
        assert mock_get.call_count == 3
        assert mock_sleep.call_count == 2

def test_ollama_client_invalid_url():
    """Test Ollama client with invalid URL"""
    with patch.dict(os.environ, {'OLLAMA_HOST': 'invalid-url'}):
        client = Ollama()
        assert client.client is None
        assert len(client.models) == 0

def test_ollama_client_models_list():
    """Test Ollama client models list retrieval"""
    mock_models = {
        "models": [
            {"name": "llama2"},
            {"name": "codellama"}
        ]
    }
    with patch('requests.get') as mock_get, \
         patch('ollama.Client') as mock_client:
        mock_get.return_value = MagicMock(status_code=200)
        mock_client.return_value.list.return_value = mock_models

        client = Ollama()
        assert len(client.models) == 2
        assert client.models[0]["name"] == "llama2"
        assert client.models[1]["name"] == "codellama"

def test_ollama_client_inference():
    """Test Ollama client inference"""
    mock_models = {
        "models": [
            {"name": "llama2"}
        ]
    }
    mock_response = {
        "response": "Test response"
    }
    with patch('requests.get') as mock_get, \
         patch('ollama.Client') as mock_client:
        mock_get.return_value = MagicMock(status_code=200)
        mock_client.return_value.list.return_value = mock_models
        mock_client.return_value.generate.return_value = mock_response

        client = Ollama()
        response = client.inference("llama2", "Test prompt")
        assert response == "Test response"
        mock_client.return_value.generate.assert_called_once()

def test_ollama_client_inference_invalid_model():
    """Test Ollama client inference with invalid model"""
    mock_models = {
        "models": [
            {"name": "llama2"}
        ]
    }
    with patch('requests.get') as mock_get, \
         patch('ollama.Client') as mock_client:
        mock_get.return_value = MagicMock(status_code=200)
        mock_client.return_value.list.return_value = mock_models

        client = Ollama()
        with pytest.raises(RuntimeError) as exc_info:
            client.inference("invalid-model", "Test prompt")
        assert "Model invalid-model not found" in str(exc_info.value)

def test_ollama_client_inference_server_error():
    """Test Ollama client inference with server error"""
    mock_models = {
        "models": [
            {"name": "llama2"}
        ]
    }
    with patch('requests.get') as mock_get, \
         patch('ollama.Client') as mock_client:
        mock_get.return_value = MagicMock(status_code=200)
        mock_client.return_value.list.return_value = mock_models
        mock_client.return_value.generate.side_effect = Exception("Server error")

        client = Ollama()
        with pytest.raises(RuntimeError) as exc_info:
            client.inference("llama2", "Test prompt")
        assert "Failed to get response from Ollama" in str(exc_info.value)