Jason Liu
2024-01-30 20:23:06 -05:00
parent 42e3301342
commit 6181f791f9
@@ -219,3 +219,39 @@ Since everything is already annotated with Pydantic, this code is very simple to
!!! warning "Where do tags come from?"
    The tag IDs, names, and instructions used here do not have to be hard-coded: they could just as easily come from a database or another source. I'll leave that as an exercise for the reader, but I hope this gives a clear picture of how user-defined classification can work.
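
As a rough sketch of the exercise above, tag definitions could be read from a table and turned into Pydantic models. Everything here is an assumption for illustration: the `TagSpec` model, the `tags` table and its columns, the `load_tags` helper, and the format of `allowed_tags_str` are hypothetical, and an in-memory SQLite database stands in for whatever store you actually use.

```python
import sqlite3
from typing import List

from pydantic import BaseModel


class TagSpec(BaseModel):
    # Hypothetical shape for a user-defined tag; adapt it to your own tag model.
    id: int
    name: str
    instructions: str


def load_tags(conn: sqlite3.Connection) -> List[TagSpec]:
    """Read every tag definition from the (assumed) `tags` table, ordered by id."""
    rows = conn.execute(
        "SELECT id, name, instructions FROM tags ORDER BY id"
    ).fetchall()
    return [TagSpec(id=r[0], name=r[1], instructions=r[2]) for r in rows]


# An in-memory database stands in for a real one.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE tags (id INTEGER PRIMARY KEY, name TEXT, instructions TEXT)"
)
conn.executemany(
    "INSERT INTO tags VALUES (?, ?, ?)",
    [
        (0, "personal", "Personal information"),
        (1, "phone", "Phone numbers"),
        (2, "email", "Email addresses"),
    ],
)

tags = load_tags(conn)

# One possible way to render the allowed tags for the prompt.
allowed_tags_str = ", ".join(f"`{t.id}: {t.name}`" for t in tags)
print(allowed_tags_str)
```

Because the tags are loaded at request time, each user or tenant could have their own table or filter without any change to the classification code.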
## Improving the Model
There are a couple of things we could do to make this system more robust.
1. Use a confidence score:
```python
from pydantic import Field

class TagWithConfidence(Tag):
    confidence: float = Field(..., ge=0, le=1, description="The confidence of the prediction, 0 is low, 1 is high")
```
2. Use multiclass classification:
Notice that the example uses `Iterable[Tag]` rather than `Tag`. This is because we might want a multiclass classification model that returns multiple tags!
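```python
# Assumes `client`, `text`, `allowed_tags_str`, and `tags` are defined earlier
# in this post, and that this call runs inside an async function.
await client.chat.completions.create(
    model="gpt-4-turbo-preview",
    messages=[
        {
            "role": "system",
            "content": "You are a world-class text tagging system.",
        },
        {
            "role": "user",
            "content": f"Describe the following text: `{text}`",
        },
        {
            "role": "user",
            "content": f"Here are the allowed tags: {allowed_tags_str}",
        },
    ],
    response_model=Iterable[Tag],  # allow several tags instead of exactly one
    validation_context={"tags": tags},  # made available to Pydantic validators
)
```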
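
The two ideas combine naturally: once the model returns several tags, each with a confidence, you can keep only the predictions you trust. The following is a minimal sketch under stated assumptions. The `Tag` shape shown here is a simplified stand-in for the model defined earlier, `confident_tags` and the `0.5` threshold are hypothetical, and the predictions are hand-written rather than real model output.

```python
from typing import Iterable, List

from pydantic import BaseModel, Field


class Tag(BaseModel):
    # Simplified stand-in for the Tag model defined earlier in the post.
    id: int
    name: str


class TagWithConfidence(Tag):
    confidence: float = Field(
        ..., ge=0, le=1,
        description="The confidence of the prediction, 0 is low, 1 is high",
    )


def confident_tags(
    predictions: Iterable[TagWithConfidence], threshold: float = 0.5
) -> List[TagWithConfidence]:
    """Keep predictions at or above the threshold, most confident first."""
    kept = [t for t in predictions if t.confidence >= threshold]
    return sorted(kept, key=lambda t: t.confidence, reverse=True)


# Hand-written predictions standing in for a response parsed with
# response_model=Iterable[TagWithConfidence].
predictions = [
    TagWithConfidence(id=0, name="personal", confidence=0.92),
    TagWithConfidence(id=1, name="phone", confidence=0.31),
    TagWithConfidence(id=2, name="email", confidence=0.77),
]

for tag in confident_tags(predictions, threshold=0.5):
    print(tag.name, tag.confidence)
```

The right threshold depends on your data, and self-reported confidence from a language model is not a calibrated probability, so it is worth checking it against a labeled sample before relying on it.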