mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
ruff
This commit is contained in:
@@ -0,0 +1,32 @@
|
||||
import pytest
|
||||
|
||||
from examples.planning.run import extract_person, extract_people, Person
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_person():
|
||||
# Test the extract_person function with a known input
|
||||
text = "John is 45 years old"
|
||||
expected_person = Person(name="John", age=45)
|
||||
person = await extract_person(text)
|
||||
assert (
|
||||
person == expected_person
|
||||
), "The extracted person does not match the expected person"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"names_and_ages, expected_people",
|
||||
[
|
||||
(
|
||||
["Alice is 30 years old", "Bob is 24 years old"],
|
||||
[Person(name="Alice", age=30), Person(name="Bob", age=24)],
|
||||
)
|
||||
],
|
||||
)
|
||||
async def test_extract_people(names_and_ages, expected_people):
|
||||
# Test the extract_people function with a list of known inputs
|
||||
people = await extract_people(names_and_ages)
|
||||
assert (
|
||||
people == expected_people
|
||||
), "The extracted people do not match the expected people"
|
||||
@@ -50,6 +50,7 @@ client = patch(AsyncOpenAI())
|
||||
class GeneratedSummary(BaseModel):
|
||||
summary: str
|
||||
|
||||
|
||||
async def summarize_text(text: str):
|
||||
response = await client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
|
||||
@@ -53,6 +53,7 @@ class GeneratedSummary(BaseModel):
|
||||
)
|
||||
summary: str
|
||||
|
||||
|
||||
async def summarize_text(text: str):
|
||||
response = await client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
@@ -94,4 +95,4 @@ if __name__ == "__main__":
|
||||
Source: de, Summary: de, Match: True, Detected: de
|
||||
Source: hi, Summary: hi, Match: True, Detected: hi
|
||||
Source: ja, Summary: ja, Match: True, Detected: ja
|
||||
"""
|
||||
"""
|
||||
|
||||
@@ -75,7 +75,7 @@ def _add_params(
|
||||
field_type = details.get(
|
||||
"type", "unknown"
|
||||
) # Might be better to fail here if there is no type since pydantic models require types
|
||||
|
||||
|
||||
if "array" in field_type and "items" not in details:
|
||||
raise ValueError("Invalid array item.")
|
||||
|
||||
@@ -83,7 +83,7 @@ def _add_params(
|
||||
if "array" in field_type and "$ref" in details["items"]:
|
||||
type_element.text = f"List[{details['title']}]"
|
||||
list_found = True
|
||||
nested_list_found = True
|
||||
nested_list_found = True
|
||||
# Check for non-nested List
|
||||
elif "array" in field_type and "type" in details["items"]:
|
||||
type_element.text = f"List[{details['items']['type']}]"
|
||||
|
||||
@@ -69,7 +69,6 @@ class OpenAISchema(BaseModel): # type: ignore[misc]
|
||||
for line in parseString(json_to_xml(cls)).toprettyxml().splitlines()[1:]
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_response(
|
||||
cls,
|
||||
|
||||
@@ -68,9 +68,9 @@ def test_list():
|
||||
name: str
|
||||
age: int
|
||||
family: List[str]
|
||||
|
||||
|
||||
resp = create(
|
||||
model="claude-3-opus-20240229", # Fails with claude-3-haiku-20240307
|
||||
model="claude-3-opus-20240229", # Fails with claude-3-haiku-20240307
|
||||
max_tokens=1024,
|
||||
max_retries=0,
|
||||
messages=[
|
||||
@@ -81,13 +81,13 @@ def test_list():
|
||||
],
|
||||
response_model=User,
|
||||
)
|
||||
|
||||
|
||||
assert isinstance(resp, User)
|
||||
assert isinstance(resp.family, List)
|
||||
for member in resp.family:
|
||||
assert isinstance(member, str)
|
||||
|
||||
|
||||
|
||||
def test_nested_list():
|
||||
class Properties(BaseModel):
|
||||
key: str
|
||||
|
||||
Reference in New Issue
Block a user