Skip to content

Commit

Permalink
fix: Implement rate limit handling with 60s wait
Browse files Browse the repository at this point in the history
- Add 60-second wait on HTTP 429 responses
- Properly handle Groq API rate limit errors
- Add comprehensive test coverage
- Improve error messaging and logging

Fixes stitionai#524

Co-Authored-By: Erkin Alp Güney <[email protected]>
  • Loading branch information
devin-ai-integration[bot] and erkinalp committed Dec 18, 2024
1 parent 9ec4699 commit 73c14f2
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 17 deletions.
34 changes: 22 additions & 12 deletions src/llm/groq_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from groq import Groq as _Groq
import requests
from requests.exceptions import HTTPError

from src.config import Config

Expand All @@ -10,15 +12,23 @@ def __init__(self):
self.client = _Groq(api_key=api_key)

def inference(self, model_id: str, prompt: str) -> str:
    """Send ``prompt`` to Groq as a single user message and return the reply text.

    Args:
        model_id: Identifier of the Groq model to query.
        prompt: User prompt; surrounding whitespace is stripped before sending.

    Returns:
        The message content of the first choice in the completion.

    Raises:
        HTTPError: When the Groq call fails because of rate limiting. The
            exception carries a synthetic ``requests.Response`` with status
            code 429 whose body is the original error text, so callers that
            already handle HTTP 429 can treat Groq failures the same way.
        Exception: Any other failure from the Groq client is re-raised
            unchanged.
    """
    try:
        # Only the network call is guarded; reading the result happens below.
        chat_completion = self.client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt.strip(),
                }
            ],
            model=model_id,
            temperature=0,
        )
    except Exception as e:
        detail = str(e)
        # Prefer the status code when the error exposes one; fall back to the
        # message text for errors that only describe the limit in words.
        rate_limited = (
            getattr(e, "status_code", None) == 429
            or "rate limit" in detail.lower()
        )
        if not rate_limited:
            raise

        # Convert Groq API errors to HTTPError for consistent handling.
        response = requests.Response()
        response.status_code = 429
        # Response has no public setter for its body; callers read it back
        # through ``response.content``.
        response._content = detail.encode()
        raise HTTPError(detail, response=response) from e

    return chat_completion.choices[0].message.content
19 changes: 14 additions & 5 deletions src/services/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Create a wrapper function that retries the wrapped call up to 5 times
import sys
import time
import requests
from functools import wraps
import json

Expand All @@ -11,9 +12,17 @@ def wrapper(*args, **kwargs):
max_tries = 5
tries = 0
while tries < max_tries:
result = func(*args, **kwargs)
if result:
return result
try:
result = func(*args, **kwargs)
if result:
return result
except requests.exceptions.HTTPError as e:
if e.response.status_code == 429:
print("Rate limit reached, waiting 60 seconds...")
emit_agent("info", {"type": "warning", "message": "Rate limit reached, waiting 60 seconds..."})
time.sleep(60)
continue
raise
print("Invalid response from the model, I'm trying again...")
emit_agent("info", {"type": "warning", "message": "Invalid response from the model, trying again..."})
tries += 1
Expand All @@ -25,7 +34,7 @@ def wrapper(*args, **kwargs):
return False
return wrapper


class InvalidResponseError(Exception):
    """Signals that a model response could not be accepted as valid."""

Expand Down Expand Up @@ -87,4 +96,4 @@ def wrapper(*args, **kwargs):
# raise InvalidResponseError("Failed to parse response as JSON")
return False

return wrapper
return wrapper
39 changes: 39 additions & 0 deletions tests/test_groq_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest
from unittest.mock import Mock, patch
from requests.exceptions import HTTPError

from src.llm.groq_client import Groq


def test_groq_rate_limit_handling():
    """A Groq rate-limit failure must surface as an HTTPError with status 429."""
    rate_limit_error = Exception(
        'Rate limit reached for model `mixtral-8x7b-32768`. Please try again in 7.164s.'
    )

    client_under_test = Groq()

    # Replace the SDK client with a stub whose completion call always fails.
    failing_sdk = Mock()
    failing_sdk.chat.completions.create.side_effect = rate_limit_error
    client_under_test.client = failing_sdk

    with pytest.raises(HTTPError) as raised:
        client_under_test.inference("mixtral-8x7b-32768", "test prompt")

    # The synthetic response should carry the 429 status and the original text.
    synthetic_response = raised.value.response
    assert synthetic_response.status_code == 429
    body_text = str(synthetic_response.content.decode())
    assert "rate limit" in body_text.lower()


def test_groq_other_error_handling():
    """Failures unrelated to rate limiting must propagate without conversion."""
    client_under_test = Groq()

    # Stub SDK client that fails with a message not mentioning rate limits.
    failing_sdk = Mock()
    failing_sdk.chat.completions.create.side_effect = Exception("Some other error")
    client_under_test.client = failing_sdk

    with pytest.raises(Exception) as raised:
        client_under_test.inference("mixtral-8x7b-32768", "test prompt")

    propagated = raised.value
    assert "Some other error" in str(propagated)
    # The original exception type is kept rather than wrapped as an HTTP error.
    assert not isinstance(propagated, HTTPError)
42 changes: 42 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest
import time
import requests
from unittest.mock import Mock, patch
from src.services.utils import retry_wrapper

def test_retry_wrapper_rate_limit():
    """On HTTP 429 the retry wrapper must wait 60 seconds before retrying."""

    class _StopRetrying(Exception):
        """Raised by the patched sleep to leave the retry loop."""

    # Mock a function that raises rate limit error
    @retry_wrapper
    def rate_limited_func():
        response = Mock(spec=requests.Response)
        response.status_code = 429
        response.json.return_value = {
            'error': {
                'message': 'Rate limit reached',
                'type': 'tokens',
                'code': 'rate_limit_exceeded'
            }
        }
        raise requests.exceptions.HTTPError(response=response)

    # After a 429 the wrapper sleeps and then calls the function again, so a
    # function that is always rate limited never lets HTTPError escape. With a
    # no-op sleep the test would spin; instead the patched sleep raises a
    # sentinel on the first wait so the loop exits deterministically.
    with patch('time.sleep', side_effect=_StopRetrying) as mock_sleep:
        with pytest.raises(_StopRetrying):
            rate_limited_func()

    # Verify it attempted to sleep for 60 seconds, exactly once.
    mock_sleep.assert_called_once_with(60)

def test_retry_wrapper_other_errors():
    """Non-429 HTTP errors must propagate immediately, without any wait."""
    attempts = []

    # Mock a function that raises other HTTP errors
    @retry_wrapper
    def other_error_func():
        attempts.append(1)
        response = Mock(spec=requests.Response)
        response.status_code = 500
        raise requests.exceptions.HTTPError(response=response)

    with patch('time.sleep') as mock_sleep:
        with pytest.raises(requests.exceptions.HTTPError) as exc_info:
            other_error_func()

    # The original error reaches the caller after a single attempt.
    assert exc_info.value.response.status_code == 500
    assert len(attempts) == 1
    # Checking "every sleep was short" passes vacuously when nothing sleeps,
    # so assert directly that no wait happened.
    mock_sleep.assert_not_called()

0 comments on commit 73c14f2

Please sign in to comment.