mirror of
https://github.com/kennethreitz/instructor.git
synced 2026-06-05 22:50:18 +00:00
stream evals of structure
This commit is contained in:
@@ -0,0 +1,60 @@
|
||||
from collections import defaultdict
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from pprint import pprint
|
||||
|
||||
from pydantic import ValidationError
|
||||
from summary_stats import StreamingAccumulatorManager
|
||||
import models as m
|
||||
|
||||
|
||||
class Status(Enum):
    """Names of the bookkeeping accumulators used while evaluating a line.

    Each member's value is used as a key into ``eval_manager.accumulator``.
    The leading underscore matters: the accumulator's ``summarize`` skips
    string-length statistics for keys that start with ``"_"``.
    """

    # Whether the raw line parsed as JSON (updated with True/False).
    IS_JSON = "_is_json_"
    # Whether the parsed object passed model validation (updated with True/False).
    IS_VALID = "_is_valid_"
    # One entry per validation error, holding the formatted error path.
    VALIDATION_ERROR = "_validation_error_"
|
||||
|
||||
|
||||
def process_line(eval_manager, line, index):
    """Evaluate one raw line and record the outcome on ``eval_manager``.

    The line is parsed as JSON, then validated as ``m.MultiSearch``.  The
    JSON and validity flags are recorded under the ``Status`` keys; a valid
    object additionally has its dumped fields accumulated, and an invalid
    one has each validation error recorded.

    Args:
        eval_manager: Object exposing ``accumulator`` (a mapping of
            accumulators) and ``update(index, data)``.
        line: Raw text of a single record.
        index: Identifier of the record, passed through to every update.
    """
    json_key = Status.IS_JSON.value
    valid_key = Status.IS_VALID.value

    try:
        payload = json.loads(line)
        eval_manager.accumulator[json_key].update(index, True)

        try:
            search = m.MultiSearch.model_validate(payload)
            eval_manager.update(index, search.model_dump())
            eval_manager.accumulator[valid_key].update(index, True)
        except ValidationError as exc:
            eval_manager.accumulator[valid_key].update(index, False)
            process_validation_error(eval_manager, exc, index)

    except json.JSONDecodeError:
        # Not JSON at all: only the JSON flag is recorded for this line.
        eval_manager.accumulator[json_key].update(index, False)
|
||||
|
||||
|
||||
def process_validation_error(eval_manager, error, index):
    """Record every error in a validation failure as a path string.

    Each error is formatted as ``$<location>.<error type>`` and pushed into
    the ``Status.VALIDATION_ERROR`` accumulator.  The location uses the same
    notation as the accumulator manager's keys: ``.name`` for a field and
    ``[*]`` (with no leading dot) for any list position, so that for example
    an error inside a list item reads ``$.queries[*].source_type.<type>``.

    Args:
        eval_manager: Object exposing ``accumulator``, a mapping of
            accumulators with an ``update(index, value)`` method.
        error: Exception exposing ``errors()``, an iterable of dicts with
            ``"loc"`` and ``"type"`` entries.
        index: Identifier of the record that failed validation.
    """
    for err in error.errors():
        path = "$"
        for part in err["loc"]:
            # Integer parts are list positions; collapse them to a wildcard
            # so errors at different positions aggregate under one path.
            path += "[*]" if isinstance(part, int) else f".{part}"
        # An empty location yields "$.<type>" rather than "$..<type>".
        path = f"{path}.{err['type']}"
        eval_manager.accumulator[Status.VALIDATION_ERROR.value].update(index, path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    eval_manager = StreamingAccumulatorManager()

    # Iterate the file object directly so lines are consumed one at a time
    # instead of loading the whole file into memory first.  JSON text is
    # UTF-8, so do not depend on the platform's default encoding.
    with open("test.jsonl", encoding="utf-8") as f:
        for ii, line in enumerate(f):
            process_line(eval_manager, line, ii)

    pprint(eval_manager.summarize())
|
||||
@@ -0,0 +1,24 @@
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SourceType(str, Enum):
    """Closed set of source identifiers accepted for a search.

    Also subclasses ``str``, so each member is itself a string equal to its
    value.
    """

    CRM = "CRM"
    WEB = "WEB"
    EMAIL = "EMAIL"
    SOCIAL_MEDIA = "SOCIAL_MEDIA"
    OTHER = "OTHER"
|
||||
|
||||
|
||||
# A single search request.
class Search(BaseModel):
    # Query text; no default, so it must be supplied.
    query: str
    # Must be one of the SourceType members; no default.
    source_type: SourceType
    # Defaults to 10 when omitted.
    results_limit: Optional[int] = Field(10)
    # Defaults to None when omitted.
    is_priority: Optional[bool] = None
    # Defaults to None when omitted.
    tags: Optional[List[str]] = None
|
||||
|
||||
|
||||
# A batch of Search requests, optionally tied to a user.
class MultiSearch(BaseModel):
    # The individual searches; no default.
    queries: List[Search]
    # No default is declared here, unlike the Optional fields on Search.
    # NOTE(review): the caller uses the pydantic v2 API (model_validate), where
    # an Optional annotation without a default presumably means "required but
    # may be None" - confirm that records lacking user_id are meant to fail.
    user_id: Optional[str]
|
||||
@@ -0,0 +1,132 @@
|
||||
# Modified StreamingAccumulator class with self.value and self.str_length as lists
|
||||
|
||||
from collections import Counter, defaultdict
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Union, List
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class StreamingAccumulator:
    """Running summary statistics for a stream of scalar values.

    Values are fed one at a time through ``update``; ``summarize`` reports
    counts plus numeric or string-length statistics depending on the kinds
    of values seen.  Sums and squared sums are kept so that mean and
    standard deviation can be derived without a second pass.
    """

    # Annotations only: every instance gets its own fresh state in __init__,
    # so nothing mutable is shared at class level.
    counter: Counter
    min: float
    max: float
    sum: float
    squared_sum: float
    unique_values: set
    missing_values: int
    str_min_length: float
    str_max_length: float
    str_sum_length: float
    str_squared_sum_length: float
    value: List[Any]
    str_length: List[int]
    reverse_lookup: defaultdict

    def __init__(self):
        """Start with empty counts and neutral min/max sentinels."""
        self.counter = Counter()
        self.min = float("inf")
        self.max = float("-inf")
        self.sum = 0
        self.squared_sum = 0
        self.unique_values = set()
        self.missing_values = 0
        self.str_min_length = float("inf")
        self.str_max_length = float("-inf")
        self.str_sum_length = 0
        self.str_squared_sum_length = 0
        self.value = []
        self.str_length = []
        self.reverse_lookup = defaultdict(list)

    def update(self, index: Any, value: Any) -> None:
        """Update statistics with a new value.

        Args:
            index: Identifier of the record the value came from; stored in
                ``reverse_lookup`` under the value.
            value: The observation.  ``None`` and ``""`` count as missing;
                values that are not int/float/str/bool are otherwise ignored.
        """
        # Floats must be counted too: summarize() derives n from the counter,
        # so leaving them out would skew (or suppress) the numeric statistics.
        if isinstance(value, (int, float, str, bool)):
            self.counter[value] += 1
            self.unique_values.add(value)
            self.value.append(value)
            self.reverse_lookup[value].append(index)

        # An empty string is counted above and also reported as missing.
        if value is None or value == "":
            self.missing_values += 1
            return

        # bool is a subclass of int, so True/False contribute 1/0 here.
        if isinstance(value, (int, float)):
            self.min = min(self.min, value)
            self.max = max(self.max, value)
            self.sum += value
            self.squared_sum += value**2

        if isinstance(value, str):
            str_len = len(value)
            self.str_length.append(str_len)
            self.str_min_length = min(self.str_min_length, str_len)
            self.str_max_length = max(self.str_max_length, str_len)
            self.str_sum_length += str_len
            self.str_squared_sum_length += str_len**2

    @staticmethod
    def _std(total: float, squared_total: float, n: int) -> float:
        """Return the population standard deviation from running sums."""
        variance = squared_total / n - (total / n) ** 2
        # Rounding can push a near-zero variance slightly below zero, which
        # would make the square root NaN; clamp it.
        return np.sqrt(max(variance, 0.0))

    def summarize(self, key_name=None) -> Dict[str, Union[int, float, dict]]:
        """Build the summary for everything seen so far.

        Args:
            key_name: Name this accumulator is stored under.  When it starts
                with ``"_"`` the string-length statistics are omitted.

        Returns:
            A dict that always holds ``counter``, ``unique_count``,
            ``missing_values`` and ``_reverse_lookup``.  All-bool data adds
            ``mean``; all-numeric data adds ``min``/``max``/``mean``/``std``;
            all-string data adds the ``str_*_length`` statistics.
        """
        if key_name is None:
            key_name = ""

        n = sum(self.counter.values())
        summaries = {}
        summaries["counter"] = self.counter
        summaries["unique_count"] = len(self.unique_values)
        summaries["missing_values"] = self.missing_values
        summaries["_reverse_lookup"] = dict(self.reverse_lookup)

        if n > 0:
            # Checked before the numeric case because bools are also ints;
            # for flags only the fraction of True values is reported.
            if all(isinstance(value, (bool)) for value in self.unique_values):
                summaries["mean"] = self.sum / n
                return summaries

            if all(isinstance(value, (int, float)) for value in self.unique_values):
                summaries["min"] = self.min
                summaries["max"] = self.max
                summaries["mean"] = self.sum / n
                summaries["std"] = self._std(self.sum, self.squared_sum, n)
                return summaries

            if all(
                isinstance(value, str) for value in self.unique_values
            ) and not key_name.startswith("_"):
                summaries["str_min_length"] = self.str_min_length
                summaries["str_max_length"] = self.str_max_length
                summaries["str_mean_length"] = self.str_sum_length / n
                summaries["str_std_length"] = self._std(
                    self.str_sum_length, self.str_squared_sum_length, n
                )
                return summaries

        return summaries
|
||||
|
||||
|
||||
class StreamingAccumulatorManager:
    """Fan a nested object out into one accumulator per path.

    Paths start at ``$``; a dict key appends ``.key`` and list items share a
    single ``[*]`` suffix, so the same field from every list element lands in
    the same accumulator.
    """

    def __init__(self):
        """Create the path -> accumulator mapping, filled on first use."""
        self.accumulator = defaultdict(StreamingAccumulator)

    def update(self, index, data: Any, path: str = "$") -> None:
        """Accumulate values from a nested object.

        Args:
            index: Identifier of the record, forwarded to every accumulator.
            data: The object (or sub-object) to walk.
            path: Path of ``data`` within the record; ``"$"`` is the root.
        """
        if isinstance(data, dict):
            for key, value in data.items():
                self.update(index, value, f"{path}.{key}")
        elif isinstance(data, list):
            item_path = f"{path}[*]"
            for value in data:
                self.update(index, value, item_path)
            # The list's size is tracked separately from its items.
            self.accumulator[f"{path}.length"].update(index, len(data))
        elif isinstance(data, Enum):
            # Checked before the generic leaf case so enum members are
            # recorded by their value under a dedicated ".enum" path.
            self.accumulator[f"{path}.enum"].update(index, data.value)
        else:
            # Every remaining value is a leaf and is recorded at its own
            # path, whether it sits at the root or is nested.
            self.accumulator[path].update(index, data)

    def summarize(self) -> Dict[str, Dict]:
        """Generate summary statistics for all paths."""
        return {k: v.summarize(key_name=k) for k, v in self.accumulator.items()}
|
||||
@@ -0,0 +1,18 @@
|
||||
{"queries": [{"query": "sales Q1", "source_type": "CRM"}], "user_id": "user_1"}
|
||||
{"queries": [{"query": "customer churn", "source_type": "WEB", "is_priority": true}], "user_id": "user_2", "total_queries": 1}
|
||||
{"queries": ["query": "email campaigns", "source_type": "EMAIL"}, {"query": "social ads", "source_type": "SOCIAL_MEDIA"}], "user_id": "user_3", "total_queries": 2}
|
||||
{"queries": [{"query": "sales Q2", "source_type": "INVALID_ENUM"}], "user_id": "user_4"}
|
||||
{queries: [{"query": "sales Q3", "source_type": "CRM"}], "user_id": "user_5"}
|
||||
{"queries": [{"query": "sales Q4", "source_type": "CRM", "timestamp": "2023-09-10T12:00:00Z"}], "total_queries": 1}
|
||||
{"queries": [{"query": "customer retention", "source_type": "EMAIL", "is_priority": "should_be_bool"}], "user_id": "user_6"}
|
||||
{"queries": [{"query": "sales Q1", "source_type": "CRM"}, {"query": "sales Q2", "source_type": "WEB"}], "user_id": "user_7", "total_queries": 2}
|
||||
{"queries": [{"query": "sales Q1", "source_type": "CRM", "timestamp": "2023-09-10T12:00:00Z"}], "user_id": "user_8", "total_queries": 1}
|
||||
{"queries": [{"query": "revenue 2022", "source_type": "WEB", "results_limit": 10, "is_priority": true}], "user_id": "user_9", "total_queries": 1}
|
||||
{"queries": [{"query": "email outreach", "source_type": "EMAIL", "tags": ["outreach", "2022"]}], "user_id": "user_10", "total_queries": 1}
|
||||
{"queries": [{"query": "product sales", "source_type": "CRM"}, {"query": "customer satisfaction", "source_type": "EMAIL"}], "user_id": "user_11", "total_queries": 2}
|
||||
{"queries": [{"query": "social impact", "source_type": "SOCIAL_MEDIA"}, {"query": "email campaigns", "source_type": "EMAIL"}, {"query": "web traffic", "source_type": "WEB"}], "user_id": "user_12", "total_queries": 3}
|
||||
{"queries": [{"query": "sales Q1", "source_type": "CRM", "is_priority": false}], "user_id": "user_13", "total_queries": 1}
|
||||
{"queries": [{"query": "marketing strategies", "source_type": "WEB", "results_limit": 15, "is_priority": true, "tags": ["strategy", "2023"]}], "user_id": "user_14", "total_queries": 1}
|
||||
{"queries": [{"query": "customer feedback", "source_type": "EMAIL", "tags": ["feedback"]}], "user_id": "user_15", "total_queries": 1}
|
||||
{"queries": [{"query": "revenue streams", "source_type": "CRM"}, {"query": "new products", "source_type": "WEB"}], "user_id": "user_16", "total_queries": 2}
|
||||
{"queries": [{"query": "social trends", "source_type": "SOCIAL_MEDIA", "is_priority": true}, {"query": "email open rates", "source_type": "EMAIL", "results_limit": 5}, {"query": "website analytics", "source_type": "WEB", "tags": ["analytics"]}], "user_id": "user_17", "total_queries": 3}
|
||||
Reference in New Issue
Block a user