# Mirror of https://github.com/kennethreitz/langchain.git
# Synced 2026-06-05 23:00:18 +00:00; latest commit 572968fee3
# ("Using langchain input type").
# File: 131 lines, 3.0 KiB, Python.

from contextlib import nullcontext
from typing import Optional

import pytest

try:
    from pydantic.v1 import BaseModel, ValidationError
except ImportError:
    from pydantic import BaseModel, ValidationError

from langserve.validation import (
    create_batch_request_model,
    create_invoke_request_model,
    create_runnable_config_model,
)


@pytest.mark.parametrize(
    "test_case",
    [
        {
            "input": {"a": "qqq"},
            "kwargs": {},
            "valid": False,
        },
        {
            "input": {"a": 2},
            "kwargs": "hello",
            "valid": False,
        },
        {
            "input": {"a": 2},
            "config": "hello",
            "valid": False,
        },
        {
            "input": {"b": "hello"},
            "valid": False,
        },
        {
            "input": {"a": 2, "b": "hello"},
            "config": "hello",
            "valid": False,
        },
        {
            "input": {"a": 2, "b": "hello"},
            "valid": True,
        },
        {
            "input": {"a": 2},
            "valid": True,
        },
    ],
)
def test_create_invoke_and_batch_models(test_case: dict) -> None:
    """Test that the invoke and batch request models validate payloads.

    Each ``test_case`` holds the keyword arguments for an invoke request
    (``input`` plus optional ``config`` / ``kwargs``) and a ``valid`` flag
    saying whether model construction is expected to succeed. The same
    payload is then checked against the batch request model, with
    ``input`` wrapped in a one-element ``inputs`` list.
    """

    class Input(BaseModel):
        """Test input."""

        a: int
        b: Optional[str] = None

    # Work on a copy: pytest hands the test the very dict object stored in
    # the parametrize list, so popping keys from it in place would corrupt
    # the case for any re-run of the same parameter set.
    payload = dict(test_case)
    valid = payload.pop("valid")
    config = create_runnable_config_model("test", ["tags"])

    model = create_invoke_request_model("namespace", Input, config)
    expectation = nullcontext() if valid else pytest.raises(ValidationError)
    with expectation:
        model(**payload)

    # Validate batch request: same structure as the invoke request, but
    # 'input' is a list of inputs and is called 'inputs'.
    batch_model = create_batch_request_model("namespace", Input, config)
    payload["inputs"] = [payload.pop("input")]
    expectation = nullcontext() if valid else pytest.raises(ValidationError)
    with expectation:
        batch_model(**payload)


@pytest.mark.parametrize(
    "test_case",
    [
        {
            "type": int,
            "input": 1,
            "valid": True,
        },
        {
            "type": float,
            "input": "name",
            "valid": False,
        },
        {
            "type": float,
            "input": [3.2],
            "valid": False,
        },
        {
            "type": float,
            "input": 1.1,
            "valid": True,
        },
        {
            "type": Optional[float],
            "valid": True,
            "input": None,
        },
    ],
)
def test_validation(test_case: dict) -> None:
    """Test that an invoke request model validates ``input`` against its type.

    Each ``test_case`` provides the input ``type`` used to build the request
    model, an ``input`` value, and a ``valid`` flag saying whether model
    construction is expected to succeed.
    """
    # Copy so the shared parametrize dict is not mutated, and strip the
    # test-only keys so that only the request payload reaches the model
    # (previously the ``valid`` flag was passed through as an extra field).
    payload = dict(test_case)
    input_type = payload.pop("type")
    valid = payload.pop("valid")

    config = create_runnable_config_model("test", [])
    model = create_invoke_request_model("namespace", input_type, config)

    expectation = nullcontext() if valid else pytest.raises(ValidationError)
    with expectation:
        model(**payload)