Add exception handling for HTTP requests in Neon_API_V2 class and update database-related methods in DatabaseResource class

This commit is contained in:
2024-01-14 08:11:51 -05:00
parent a8b9e03318
commit a9c08ecc6d
4 changed files with 89 additions and 48 deletions
+7
View File
@@ -0,0 +1,7 @@
from requests.exceptions import HTTPError
class NeonClientException(HTTPError):
"""Base exception class for all exceptions raised by the Neon Client."""
pass
+5 -1
View File
@@ -3,6 +3,7 @@ import requests
from pydantic import BaseModel
from .openapi_models import *
from .exceptions import NeonClientException
from .__version__ import __version__
@@ -52,7 +53,10 @@ class Neon_API_V2:
if check_status_code:
# TODO: add custom exception classes here.
r.raise_for_status()
try:
r.raise_for_status()
except:
raise NeonClientException(r.text)
if response_model:
if response_is_array:
+65 -20
View File
@@ -18,7 +18,7 @@ from .openapi_models import (
Database2,
DatabaseCreateRequest,
)
from .utils import validate_with_model
from .utils import validate_obj_model
class Resource:
@@ -172,8 +172,36 @@ class ProjectResource(Resource):
class DatabaseResource(Resource):
path = "databases"
response_model = DatabasesResponse
response_model_single = DatabaseResponse
CreateRequest = DatabaseCreateRequest
UpdateRequest = DatabaseUpdateRequest
Database = Database
def _extract_database(self, obj):
"""Extract a database from the specified object."""
assert isinstance(obj, (DatabaseCreateRequest, Database1, Database2, Database))
# Object mappings.
if isinstance(obj, DatabaseCreateRequest):
obj = obj.database.model_dump()
if isinstance(obj, Database1):
obj = obj.database.model_dump()
if isinstance(obj, Database2):
obj = obj.database.model_dump()
if isinstance(obj, Database):
obj = obj.model_dump()
return obj
def get_databases(self, project_id: str, branch_id: str):
"""Get a list of databases."""
"""Get a list of databases.
See also: https://api-docs.neon.tech/reference/listprojectbranchdatabases
"""
databases_response = self.api.request(
method="GET",
@@ -185,48 +213,65 @@ class DatabaseResource(Resource):
return databases_response.databases
def get_database(self, project_id: str, database_id: str):
def get_database(
self,
project_id: str,
branch_id: str,
database_name: str,
):
"""Get a database."""
database_response = self.api.request(
method="GET",
path=self.api.url_join("projects", project_id, "databases", database_id),
path=self.api.url_join(
"projects",
project_id,
"branches",
branch_id,
"databases",
database_name,
),
response_model=DatabaseResponse,
)
return database_response.database
@validate_with_model(DatabaseCreateRequest)
def create_database(
self, project_id: str, branch_id: str, *, obj: DatabaseCreateRequest, **kwargs
self,
project_id: str,
branch_id: str,
db: DatabaseCreateRequest | Database1 | Database2 | Database,
):
"""Create a new database."""
# TODO: untested.
db = self._extract_database(db)
database_create_response = self.api.request(
method="POST",
path=self.api.url_join(
"projects", project_id, "branches", branch_id, "databases"
),
json=obj.model_dump(),
json=DatabaseCreateRequest(database=db).model_dump(),
response_model=DatabaseResponse,
)
return database_create_response.database
def update_database(self, project_id: str, database: Database2):
def update_database(
self,
project_id: str,
branch_id: str,
database_id: str,
db: DatabaseUpdateRequest | Database2,
):
"""Update a database."""
# TODO: This is not working yet.
payload = DatabaseUpdateRequest(database=database.model_dump())
db = self._extract_database(db)
return self.api.request(
database_update_response = self.api.request(
method="PATCH",
path=self.api.url_join("projects", project_id, "databases", database.id),
json=payload.model_dump(),
path=self.api.url_join(
"projects", project_id, "branches", branch_id, "databases", database_id
),
json=DatabaseUpdateRequest(database=db).model_dump(),
response_model=DatabaseResponse,
)
def new(self, name, **kwargs):
"""Create a new database."""
db = Database1.model_construct(name=name, **kwargs)
return db
return database_update_response.database
+12 -27
View File
@@ -3,38 +3,23 @@ from typing import List, Dict, Any, Union, Optional
from pydantic import BaseModel
def validate_with_model(*models):
"""a decorator that will use the Pydantic model to parse and validate the input"""
def validate_obj_model(parameter_name: str, model: BaseModel):
"""A decorator that validates the 'obj' argument against the specified model."""
def decorator(func):
def wrapper(*args, **kwargs):
# Merge args and kwargs into a single dict
sig = inspect.signature(func)
bound_args = sig.bind_partial(*args, **kwargs)
bound_args.apply_defaults()
keyword_args = bound_args.arguments.pop("kwargs")
# Get the 'obj' argument.
obj = kwargs.get(parameter_name)
# creating Pydantic classes
for name, param in sig.parameters.items():
if issubclass(param.annotation, BaseModel):
for model in models:
if model == param.annotation:
model_dict = {
key: value
for key, value in keyword_args.items()
if key in model.model_fields
}
model_object = model(**model_dict)
keyword_args[name] = model_object
keyword_args = {
key: value
for key, value in keyword_args.items()
if key not in model_dict.keys()
}
break
# If the 'obj' argument is not provided, raise an exception.
if obj is None:
raise ValueError(f"Missing required argument '{parameter_name}'.")
# Pass the model instance(s) to the wrapped function
return func(*bound_args.args, **keyword_args)
# Validate the 'obj' argument against the specified model.
model.model_validate(obj)
# Call the wrapped function.
return func(*args, **kwargs)
return wrapper