Skip to content

Commit

Permalink
fix: Fix a regression in the generate extension. (#3)
Browse files Browse the repository at this point in the history
* fix import to silence pylance warning

* fix generate extension

* fix tests
  • Loading branch information
masci authored May 3, 2024
1 parent 1cbd2b5 commit b2d5805
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
14 changes: 9 additions & 5 deletions src/banks/extensions/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,7 @@ def _generate(self, text, model_name=DEFAULT_MODEL):
{"role": "user", "content": text},
]
response: ModelResponse = cast(ModelResponse, completion(model=model_name, messages=messages))
content: str = response["choices"][0]["message"]["content"]
if SYSTEM_PROMPT.canary_leaked(content):
msg = "The system prompt has leaked into the response, possible prompt injection!"
raise CanaryWordError(msg)
return self._get_content(response)

async def _agenerate(self, text, model_name=DEFAULT_MODEL):
"""
Expand All @@ -78,4 +75,11 @@ async def _agenerate(self, text, model_name=DEFAULT_MODEL):
{"role": "user", "content": text},
]
response: ModelResponse = cast(ModelResponse, await acompletion(model=model_name, messages=messages))
return response["choices"][0]["message"]["content"]
return self._get_content(response)

def _get_content(self, response: ModelResponse) -> str:
content = response["choices"][0]["message"]["content"]
if SYSTEM_PROMPT.canary_leaked(content):
msg = "The system prompt has leaked into the response, possible prompt injection!"
raise CanaryWordError(msg)
return content
2 changes: 1 addition & 1 deletion src/banks/filters/lemmatize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from banks.errors import MissingDependencyError

try:
from simplemma import text_lemmatizer
from simplemma.simplemma import text_lemmatizer

simplemma_avail = True
except ImportError:
Expand Down
2 changes: 1 addition & 1 deletion src/banks/templates/generate_tweet.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ Generate a tweet about the topic {{ topic }} with a positive sentiment.
#}
Examples:
{% for number in range(3) %}
- {% generate "write a tweet with positive sentiment" "gpt-3.5-turbo" %}
- {% generate "write a tweet with positive sentiment", "gpt-3.5-turbo" %}

{% endfor %}

0 comments on commit b2d5805

Please sign in to comment.