Merge pull request #23 from barisozmen/issue_15

Add default kwargs logic into Anthropic provider, which is superseded by user entered kwargs
This commit is contained in:
2024-10-31 16:00:46 -04:00
committed by GitHub
+7 -5
View File
@@ -14,11 +14,13 @@ T = TypeVar("T", bound=BaseModel)
PROVIDER_NAME = "anthropic"
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
DEFAULT_MAX_TOKENS = 1000
DEFAULT_KWARGS = {"max_tokens": DEFAULT_MAX_TOKENS}
class Anthropic(BaseProvider):
NAME = PROVIDER_NAME
DEFAULT_MODEL = DEFAULT_MODEL
DEFAULT_KWARGS = DEFAULT_KWARGS
def __init__(self, api_key: str | None = None):
self.api_key = api_key or settings.get_api_key(PROVIDER_NAME)
@@ -46,8 +48,7 @@ class Anthropic(BaseProvider):
response = self.client.messages.create(
model=conversation.llm_model or self.DEFAULT_MODEL,
messages=messages,
max_tokens=DEFAULT_MAX_TOKENS,
**kwargs,
**{**self.DEFAULT_KWARGS, **kwargs},
)
# Get the response content from the Anthropic response
@@ -64,7 +65,9 @@ class Anthropic(BaseProvider):
def structured_response(self, model: str, response_model: Type[T], **kwargs) -> T:
response = self.structured_client.messages.create(
model=model or self.DEFAULT_MODEL, response_model=response_model, **kwargs
model=model or self.DEFAULT_MODEL,
response_model=response_model,
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response
@@ -76,8 +79,7 @@ class Anthropic(BaseProvider):
response = self.client.messages.create(
model=llm_model or self.DEFAULT_MODEL,
messages=messages,
max_tokens=DEFAULT_MAX_TOKENS,
**kwargs,
**{**self.DEFAULT_KWARGS, **kwargs},
)
return response.content[0].text