Add pagination support to Neon_API_V2 class

This commit is contained in:
2024-01-14 17:27:28 -05:00
parent b547bb0941
commit 10368c78f1
3 changed files with 83 additions and 49 deletions
+10 -3
View File
@@ -24,6 +24,7 @@ class Neon_API_V2:
response_model: BaseModel = None,
response_is_array=False,
check_status_code=True,
include_pagination=False,
**kwargs,
):
"""
@@ -62,13 +63,19 @@ class Neon_API_V2:
if response_is_array:
# Shortcut for when the response is a list of items.
if type(response_is_array) == "str":
return [
response_parsed = [
response_model(**item) for item in r.json()[response_is_array]
]
elif response_is_array == True:
return [response_model(**item) for item in r.json()]
response_parsed = [response_model(**item) for item in r.json()]
else:
return response_model(**r.json())
response_parsed = response_model(**r.json())
if include_pagination:
pagination = PaginationResponse(**r.json()).pagination
response_parsed.pagination = pagination
return response_parsed
else:
return r
+73 -46
View File
@@ -20,10 +20,19 @@ from .openapi_models import (
DatabaseCreateRequest,
BranchesResponse,
BranchResponse,
OperationResponse,
OperationsResponse,
PaginationResponse,
)
from .utils import validate_obj_model
class PagedOperationsResponse(OperationsResponse, PaginationResponse):
"""A response containing a list of operations and pagination information."""
pass
class Resource:
base_path = None
@@ -50,17 +59,15 @@ class UserResource(Resource):
class APIKeyResource(Resource):
"""A resource for interacting with API keys."""
path = "api_keys"
response_model = ApiKeysListResponseItem
# response_model_single =
base_path = "api_keys"
def get_list(self):
"""Get a list of API keys."""
return self.api.request(
method="GET",
path=self.path,
response_model=self.response_model,
path=self.base_path,
response_model=ApiKeysListResponseItem,
response_is_array=True,
)
@@ -69,7 +76,7 @@ class APIKeyResource(Resource):
return self.api.request(
method="POST",
path=self.path,
path=self.base_path,
json=ApiKeyCreateRequest(key_name=key_name).model_dump(),
response_model=ApiKeyCreateResponse,
)
@@ -79,16 +86,7 @@ class APIKeyResource(Resource):
return self.api.request(
method="DELETE",
path=self.api.url_join(self.path, key_id),
response_model=ApiKeyRevokeResponse,
)
def get(self, key_id: str):
"""Get an API key."""
return self.api.request(
"GET",
f"api_keys/{key_id}",
path=self.api.url_join(self.base_path, str(key_id)),
response_model=ApiKeyRevokeResponse,
)
@@ -96,7 +94,7 @@ class APIKeyResource(Resource):
class ProjectResource(Resource):
"""A resource for interacting with projects."""
path = "projects"
base_path = "projects"
response_model = ProjectsResponse
response_model_single = ProjectResponse
@@ -105,7 +103,11 @@ class ProjectResource(Resource):
project_response = self.api.request(
method="GET",
path=(self.api.url_join(self.path, "shared") if shared else self.path),
path=(
self.api.url_join(self.base_path, "shared")
if shared
else self.base_path
),
response_model=self.response_model,
)
return project_response.projects
@@ -115,7 +117,7 @@ class ProjectResource(Resource):
project_response = self.api.request(
method="GET",
path=self.api.url_join(self.path, project_id),
path=self.api.url_join(self.base_path, project_id),
response_model=self.response_model_single,
)
return project_response.project
@@ -125,7 +127,7 @@ class ProjectResource(Resource):
project_create_response = self.api.request(
method="POST",
path=self.path,
path=self.base_path,
json={"project": {"name": name, **kwargs}},
response_model=self.response_model_single,
)
@@ -138,7 +140,7 @@ class ProjectResource(Resource):
return self.api.request(
method="PATCH",
path=self.api.url_join(self.path, project.id),
path=self.api.url_join(self.base_path, project.id),
json={"project": payload.model_dump()},
response_model=self.response_model_single,
)
@@ -148,20 +150,16 @@ class ProjectResource(Resource):
return self.api.request(
method="DELETE",
path=self.api.url_join(self.path, project_id),
path=self.api.url_join(self.base_path, project_id),
response_model=self.response_model_single,
)
class DatabaseResource(Resource):
path = "databases"
base_path = "databases"
response_model = DatabasesResponse
response_model_single = DatabaseResponse
CreateRequest = DatabaseCreateRequest
UpdateRequest = DatabaseUpdateRequest
Database = Database
def _extract_database(self, obj):
"""Extract a database from the specified object."""
@@ -185,15 +183,13 @@ class DatabaseResource(Resource):
See also: https://api-docs.neon.tech/reference/listprojectbranchdatabases
"""
databases_response = self.api.request(
return self.api.request(
method="GET",
path=self.api.url_join(
"projects", project_id, "branches", branch_id, "databases"
),
response_model=DatabasesResponse,
)
return databases_response.databases
).model_dump()
def get(
self,
@@ -203,7 +199,7 @@ class DatabaseResource(Resource):
):
"""Get a database."""
database_response = self.api.request(
return self.api.request(
method="GET",
path=self.api.url_join(
"projects",
@@ -214,8 +210,7 @@ class DatabaseResource(Resource):
database_name,
),
response_model=DatabaseResponse,
)
return database_response.database
).model_dump()
def create(
self,
@@ -227,15 +222,14 @@ class DatabaseResource(Resource):
db = self._extract_database(db)
database_create_response = self.api.request(
return self.api.request(
method="POST",
path=self.api.url_join(
"projects", project_id, "branches", branch_id, "databases"
),
json=DatabaseCreateRequest(database=db).model_dump(),
response_model=DatabaseResponse,
)
return database_create_response.database
).model_dump()
def update(
self,
@@ -248,15 +242,14 @@ class DatabaseResource(Resource):
db = self._extract_database(db)
database_update_response = self.api.request(
return self.api.request(
method="PATCH",
path=self.api.url_join(
"projects", project_id, "branches", branch_id, "databases", database_id
),
json=DatabaseUpdateRequest(database=db).model_dump(),
response_model=DatabaseResponse,
)
return database_update_response.database
).model_dump()
class BranchResource(Resource):
@@ -269,23 +262,56 @@ class BranchResource(Resource):
def get_list(self, project_id: str):
"""Get a list of branches."""
branches_response = self.api.request(
return self.api.request(
method="GET",
path=self.api.url_join("projects", project_id, "branches"),
response_model=self.response_model,
)
return branches_response.branches
).model_dump()
def get(self, project_id: str, branch_id: str):
"""Get a branch."""
branch_response = self.api.request(
return self.api.request(
method="GET",
path=self.api.url_join("projects", project_id, "branches", branch_id),
response_model=self.response_model_single,
)
).model_dump()
return branch_response.branch
class OperationResource(Resource):
"""A resource for interacting with operations."""
base_path = "operations"
def get_list(
self,
project_id: str,
cursor: int | None = None,
limit: int | None = None,
):
"""Get a list of operations."""
operations_params = {}
if cursor is not None:
operations_params["cursor"] = cursor
if limit is not None:
operations_params["limit"] = limit
return self.api.request(
method="GET",
path=self.api.url_join("projects", project_id, "operations"),
params=operations_params,
response_model=PagedOperationsResponse,
).model_dump()
def get(self, project_id: str, operation_id: str):
"""Get an operation."""
return self.api.request(
method="GET",
path=self.api.url_join("projects", project_id, "operations", operation_id),
response_model=OperationResponse,
).model_dump()
class ResourceCollection:
@@ -298,3 +324,4 @@ class ResourceCollection:
self.projects = ProjectResource(api)
self.databases = DatabaseResource(api)
self.branches = BranchResource(api)
self.operations = OperationResource(api)