From a9c08ecc6d5e85453d6e00ffb86517244847bbf1 Mon Sep 17 00:00:00 2001 From: Kenneth Reitz Date: Sun, 14 Jan 2024 08:11:51 -0500 Subject: [PATCH] Add exception handling for HTTP requests in Neon_API_V2 class and update database-related methods in DatabaseResource class --- neon_client/exceptions.py | 7 ++++ neon_client/http_client.py | 6 ++- neon_client/resources.py | 85 +++++++++++++++++++++++++++++--------- neon_client/utils.py | 39 ++++++----------- 4 files changed, 89 insertions(+), 48 deletions(-) create mode 100644 neon_client/exceptions.py diff --git a/neon_client/exceptions.py b/neon_client/exceptions.py new file mode 100644 index 0000000..5747844 --- /dev/null +++ b/neon_client/exceptions.py @@ -0,0 +1,7 @@ +from requests.exceptions import HTTPError + + +class NeonClientException(HTTPError): + """Base exception class for all exceptions raised by the Neon Client.""" + + pass diff --git a/neon_client/http_client.py b/neon_client/http_client.py index 10ddce1..4d4862e 100644 --- a/neon_client/http_client.py +++ b/neon_client/http_client.py @@ -3,6 +3,7 @@ import requests from pydantic import BaseModel from .openapi_models import * +from .exceptions import NeonClientException from .__version__ import __version__ @@ -52,7 +53,10 @@ class Neon_API_V2: if check_status_code: # TODO: add custom exception classes here. - r.raise_for_status() + try: + r.raise_for_status() + except: + raise NeonClientException(r.text) if response_model: if response_is_array: diff --git a/neon_client/resources.py b/neon_client/resources.py index 84af2fc..1feceac 100644 --- a/neon_client/resources.py +++ b/neon_client/resources.py @@ -18,7 +18,7 @@ from .openapi_models import ( Database2, DatabaseCreateRequest, ) -from .utils import validate_with_model +from .utils import validate_obj_model class Resource: @@ -172,8 +172,36 @@ class ProjectResource(Resource): class DatabaseResource(Resource): + path = "databases" + response_model = DatabasesResponse + response_model_single = DatabaseResponse + + CreateRequest = DatabaseCreateRequest + UpdateRequest = DatabaseUpdateRequest + Database = Database + + def _extract_database(self, obj): + """Extract a database from the specified object.""" + + assert isinstance(obj, (DatabaseCreateRequest, Database1, Database2, Database)) + + # Object mappings. + if isinstance(obj, DatabaseCreateRequest): + obj = obj.database.model_dump() + if isinstance(obj, Database1): + obj = obj.database.model_dump() + if isinstance(obj, Database2): + obj = obj.database.model_dump() + if isinstance(obj, Database): + obj = obj.model_dump() + + return obj + def get_databases(self, project_id: str, branch_id: str): - """Get a list of databases.""" + """Get a list of databases. + + See also: https://api-docs.neon.tech/reference/listprojectbranchdatabases + """ databases_response = self.api.request( method="GET", @@ -185,48 +213,65 @@ class DatabaseResource(Resource): return databases_response.databases - def get_database(self, project_id: str, database_id: str): + def get_database( + self, + project_id: str, + branch_id: str, + database_name: str, + ): """Get a database.""" database_response = self.api.request( method="GET", - path=self.api.url_join("projects", project_id, "databases", database_id), + path=self.api.url_join( + "projects", + project_id, + "branches", + branch_id, + "databases", + database_name, + ), response_model=DatabaseResponse, ) return database_response.database - @validate_with_model(DatabaseCreateRequest) def create_database( - self, project_id: str, branch_id: str, *, obj: DatabaseCreateRequest, **kwargs + self, + project_id: str, + branch_id: str, + db: DatabaseCreateRequest | Database1 | Database2 | Database, ): """Create a new database.""" - # TODO: untested. + + db = self._extract_database(db) database_create_response = self.api.request( method="POST", path=self.api.url_join( "projects", project_id, "branches", branch_id, "databases" ), - json=obj.model_dump(), + json=DatabaseCreateRequest(database=db).model_dump(), response_model=DatabaseResponse, ) return database_create_response.database - def update_database(self, project_id: str, database: Database2): + def update_database( + self, + project_id: str, + branch_id: str, + database_id: str, + db: DatabaseUpdateRequest | Database2, + ): """Update a database.""" - # TODO: This is not working yet. - payload = DatabaseUpdateRequest(database=database.model_dump()) + db = self._extract_database(db) - return self.api.request( + database_update_response = self.api.request( method="PATCH", - path=self.api.url_join("projects", project_id, "databases", database.id), - json=payload.model_dump(), + path=self.api.url_join( + "projects", project_id, "branches", branch_id, "databases", database_id + ), + json=DatabaseUpdateRequest(database=db).model_dump(), response_model=DatabaseResponse, ) - - def new(self, name, **kwargs): - """Create a new database.""" - - db = Database1.model_construct(name=name, **kwargs) - return db + return database_update_response.database diff --git a/neon_client/utils.py b/neon_client/utils.py index d7d790e..44eca33 100644 --- a/neon_client/utils.py +++ b/neon_client/utils.py @@ -3,38 +3,23 @@ from typing import List, Dict, Any, Union, Optional from pydantic import BaseModel -def validate_with_model(*models): - """a decorator that will use the Pydantic model to parse and validate the input""" +def validate_obj_model(parameter_name: str, model: BaseModel): + """A decorator that validates the 'obj' argument against the specified model.""" def decorator(func): def wrapper(*args, **kwargs): - # Merge args and kwargs into a single dict - sig = inspect.signature(func) - bound_args = sig.bind_partial(*args, **kwargs) - bound_args.apply_defaults() - keyword_args = bound_args.arguments.pop("kwargs") + # Get the 'obj' argument. + obj = kwargs.get(parameter_name) - # creating Pydantic classes - for name, param in sig.parameters.items(): - if issubclass(param.annotation, BaseModel): - for model in models: - if model == param.annotation: - model_dict = { - key: value - for key, value in keyword_args.items() - if key in model.model_fields - } - model_object = model(**model_dict) - keyword_args[name] = model_object - keyword_args = { - key: value - for key, value in keyword_args.items() - if key not in model_dict.keys() - } - break + # If the 'obj' argument is not provided, raise an exception. + if obj is None: + raise ValueError(f"Missing required argument '{parameter_name}'.") - # Pass the model instance(s) to the wrapped function - return func(*bound_args.args, **keyword_args) + # Validate the 'obj' argument against the specified model. + model.model_validate(obj) + + # Call the wrapped function. + return func(*args, **kwargs) return wrapper