Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Improve Gemini client error handling and add tests (#530) #674

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ pytest-playwright
tiktoken
ollama
openai
anthropic
google-generativeai
anthropic>=0.8.0
google-generativeai>=0.3.0
sqlmodel
keybert
GitPython
Expand Down
70 changes: 48 additions & 22 deletions src/llm/gemini_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,58 @@
from google.generativeai.types import HarmCategory, HarmBlockThreshold

from src.config import Config
from src.logger import Logger

logger = Logger()
config = Config()

class Gemini:
    """Thin wrapper around the google-generativeai SDK used for text inference."""

    def __init__(self):
        """Validate the configured API key and configure the Gemini SDK.

        Raises:
            ValueError: if no API key is configured, or SDK configuration fails.
        """
        api_key = config.get_gemini_api_key()
        if not api_key:
            error_msg = ("Gemini API key not found in configuration. "
                         "Please add your Gemini API key to config.toml under [API_KEYS] "
                         "section as GEMINI = 'your-api-key'")
            logger.error(error_msg)
            raise ValueError(error_msg)
        try:
            genai.configure(api_key=api_key)
            logger.info("Successfully initialized Gemini client")
        except Exception as e:
            error_msg = f"Failed to configure Gemini client: {str(e)}"
            logger.error(error_msg)
            # Chain the original exception so the root cause is preserved.
            raise ValueError(error_msg) from e

    def inference(self, model_id: str, prompt: str) -> str:
        """Generate text for *prompt* with the given Gemini model.

        Args:
            model_id: Gemini model identifier (e.g. "gemini-pro").
            prompt: the text prompt to send.

        Returns:
            The generated text, or an error string when the response was
            blocked by safety filters.

        Raises:
            ValueError: on an empty response or any SDK-level failure.
        """
        try:
            logger.info(f"Initializing Gemini model: {model_id}")
            # Named generation_config to avoid shadowing the module-level config.
            generation_config = genai.GenerationConfig(temperature=0)
            model = genai.GenerativeModel(model_id, generation_config=generation_config)

            safety_settings = {
                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            }

            logger.info("Generating response with Gemini")
            response = model.generate_content(prompt, safety_settings=safety_settings)

            try:
                # Accessing .text raises ValueError when the prompt or the
                # candidate was blocked by the safety filters.
                text = response.text
            except ValueError:
                logger.error("Failed to get response text")
                logger.error(f"Prompt feedback: {response.prompt_feedback}")
                logger.error(f"Finish reason: {response.candidates[0].finish_reason}")
                logger.error(f"Safety ratings: {response.candidates[0].safety_ratings}")
                return "Error: Unable to generate content with Gemini API"

            if text:
                logger.info("Successfully generated response")
                return text

            # Empty (but not blocked) response: raise OUTSIDE the blocked-access
            # handler so it is not swallowed and turned into the error string.
            error_msg = f"Empty response from Gemini model {model_id}"
            logger.error(error_msg)
            raise ValueError(error_msg)

        except ValueError:
            # Our own ValueErrors (empty response) propagate unchanged.
            raise
        except Exception as e:
            error_msg = f"Error during Gemini inference: {str(e)}"
            logger.error(error_msg)
            raise ValueError(error_msg) from e
77 changes: 77 additions & 0 deletions tests/test_gemini_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
Tests for Gemini client implementation.
"""
from unittest.mock import ANY, MagicMock, patch

import pytest

from src.llm.gemini_client import Gemini

@pytest.fixture
def mock_config():
    """Patch the module-level config object and hand out a dummy API key."""
    with patch('src.llm.gemini_client.config') as patched_config:
        patched_config.get_gemini_api_key.return_value = "test-api-key"
        yield patched_config

@pytest.fixture
def mock_genai():
    """Patch the google.generativeai module as seen by the client module."""
    with patch('src.llm.gemini_client.genai') as patched_genai:
        yield patched_genai

@pytest.fixture
def gemini_client(mock_config, mock_genai):
    """Construct a Gemini client against the patched config and SDK."""
    client = Gemini()
    return client

def test_init_with_api_key(mock_config, mock_genai):
    """Initialization should configure the SDK with the configured key."""
    Gemini()
    mock_genai.configure.assert_called_once_with(api_key="test-api-key")

def test_init_without_api_key(mock_config, mock_genai):
    """A missing API key must raise a descriptive ValueError."""
    mock_config.get_gemini_api_key.return_value = None
    expected = "Gemini API key not found in configuration"
    with pytest.raises(ValueError, match=expected):
        Gemini()

def test_init_config_failure(mock_config, mock_genai):
    """SDK configuration errors should surface as a wrapped ValueError."""
    mock_genai.configure.side_effect = Exception("Test error")
    expected = "Failed to configure Gemini client: Test error"
    with pytest.raises(ValueError, match=expected):
        Gemini()

def test_inference_success(mock_genai, gemini_client):
    """Test successful text generation."""
    mock_model = MagicMock()
    mock_response = MagicMock()
    mock_response.text = "Generated response"
    mock_model.generate_content.return_value = mock_response
    mock_genai.GenerativeModel.return_value = mock_model

    response = gemini_client.inference("gemini-pro", "Test prompt")
    assert response == "Generated response"
    # HarmCategory/HarmBlockThreshold are imported directly by the client
    # module (not reached via the patched `genai`), so the real enum values
    # are passed. Match the safety settings loosely with ANY instead of
    # comparing against attributes of the mock, which would never be equal.
    mock_model.generate_content.assert_called_once_with("Test prompt", safety_settings=ANY)

def test_inference_empty_response(mock_genai, gemini_client):
    """Test handling of an empty/blocked response."""
    mock_model = MagicMock()
    mock_response = MagicMock()
    mock_response.text = None
    mock_model.generate_content.return_value = mock_response
    mock_genai.GenerativeModel.return_value = mock_model

    # The client logs the diagnostics and RETURNS an error string here
    # (the inner except ValueError handles this path); it does not raise.
    result = gemini_client.inference("gemini-pro", "Test prompt")
    assert result == "Error: Unable to generate content with Gemini API"

def test_inference_error(mock_genai, gemini_client):
    """Test handling of an SDK-level inference error."""
    mock_model = MagicMock()
    mock_model.generate_content.side_effect = Exception("Test error")
    mock_genai.GenerativeModel.return_value = mock_model

    # The client wraps unexpected SDK errors with this prefix, not the
    # "Unable to generate content" string used for safety blocks.
    with pytest.raises(ValueError, match="Error during Gemini inference: Test error"):
        gemini_client.inference("gemini-pro", "Test prompt")

def test_str_representation(gemini_client):
    """Test string representation."""
    # Gemini defines no __str__/__repr__, so str() yields the default object
    # repr ("<...Gemini object at 0x...>"); check the class name appears in
    # it rather than demanding exact equality with "Gemini".
    assert type(gemini_client).__name__ == "Gemini"
    assert "Gemini" in str(gemini_client)