This commit is contained in:
Jason Liu
2024-03-28 18:14:14 -04:00
parent 16c36ca96c
commit 6d23442d67
6 changed files with 41 additions and 8 deletions
+32
View File
@@ -0,0 +1,32 @@
import pytest
from examples.planning.run import extract_person, extract_people, Person
@pytest.mark.asyncio
async def test_extract_person():
# Test the extract_person function with a known input
text = "John is 45 years old"
expected_person = Person(name="John", age=45)
person = await extract_person(text)
assert (
person == expected_person
), "The extracted person does not match the expected person"
@pytest.mark.asyncio
@pytest.mark.parametrize(
"names_and_ages, expected_people",
[
(
["Alice is 30 years old", "Bob is 24 years old"],
[Person(name="Alice", age=30), Person(name="Bob", age=24)],
)
],
)
async def test_extract_people(names_and_ages, expected_people):
# Test the extract_people function with a list of known inputs
people = await extract_people(names_and_ages)
assert (
people == expected_people
), "The extracted people do not match the expected people"
+1
View File
@@ -50,6 +50,7 @@ client = patch(AsyncOpenAI())
class GeneratedSummary(BaseModel):
summary: str
async def summarize_text(text: str):
response = await client.chat.completions.create(
model="gpt-3.5-turbo",
+2 -1
View File
@@ -53,6 +53,7 @@ class GeneratedSummary(BaseModel):
)
summary: str
async def summarize_text(text: str):
response = await client.chat.completions.create(
model="gpt-3.5-turbo",
@@ -94,4 +95,4 @@ if __name__ == "__main__":
Source: de, Summary: de, Match: True, Detected: de
Source: hi, Summary: hi, Match: True, Detected: hi
Source: ja, Summary: ja, Match: True, Detected: ja
"""
"""
+2 -2
View File
@@ -75,7 +75,7 @@ def _add_params(
field_type = details.get(
"type", "unknown"
) # Might be better to fail here if there is no type since pydantic models require types
if "array" in field_type and "items" not in details:
raise ValueError("Invalid array item.")
@@ -83,7 +83,7 @@ def _add_params(
if "array" in field_type and "$ref" in details["items"]:
type_element.text = f"List[{details['title']}]"
list_found = True
nested_list_found = True
nested_list_found = True
# Check for non-nested List
elif "array" in field_type and "type" in details["items"]:
type_element.text = f"List[{details['items']['type']}]"
-1
View File
@@ -69,7 +69,6 @@ class OpenAISchema(BaseModel): # type: ignore[misc]
for line in parseString(json_to_xml(cls)).toprettyxml().splitlines()[1:]
)
@classmethod
def from_response(
cls,
+4 -4
View File
@@ -68,9 +68,9 @@ def test_list():
name: str
age: int
family: List[str]
resp = create(
model="claude-3-opus-20240229", # Fails with claude-3-haiku-20240307
model="claude-3-opus-20240229", # Fails with claude-3-haiku-20240307
max_tokens=1024,
max_retries=0,
messages=[
@@ -81,13 +81,13 @@ def test_list():
],
response_model=User,
)
assert isinstance(resp, User)
assert isinstance(resp.family, List)
for member in resp.family:
assert isinstance(member, str)
def test_nested_list():
class Properties(BaseModel):
key: str