implement NeonUser and NeonAPIKey classes

This commit is contained in:
2024-01-16 11:33:31 -05:00
parent 10368c78f1
commit a65283958b
9 changed files with 1298 additions and 777 deletions
+13 -4
View File
@@ -1,5 +1,9 @@
gen-model: fetch-v2-schema
datamodel-codegen --input v2.json --output model.py --use-standard-collections --output models.py \
datamodel-codegen \
--input v2.json \
--collapse-root-models \
--output neon_client/openapi_models.py \
--use-standard-collections \
--output-model-type pydantic_v2.BaseModel \
--input-file-type openapi \
--use-standard-collections \
@@ -8,12 +12,17 @@ gen-model: fetch-v2-schema
--use-schema-description \
--snake-case-field \
--enable-version-header \
--enum-field-as-literal one \
--use-double-quotes \
--field-constraints \
--allow-population-by-field-name \
--strict-nullable \
--use-title-as-name
--use-title-as-name \
--reuse-model \
--field-constraints \
--disable-appending-item-suffix \
--allow-extra-fields \
--use-annotated \
--capitalise-enum-members \
--use-unique-items-as-set
fetch-v2-schema:
curl -O https://neon.tech/api_spec/release/v2.json
+1
View File
@@ -8,6 +8,7 @@ openapi-python-client = "*"
pip = "*"
datamodel-code-generator = {extras = ["http"], version = "*"}
requests = "*"
fastapi = "*"
[dev-packages]
Generated
+19 -2
View File
@@ -1,7 +1,7 @@
{
"_meta": {
"hash": {
"sha256": "704bfec6d465cd7b19bf01a611497b98b94d548ae602841c6d6f958663313a07"
"sha256": "1259587fdd4e8600832ae72610bbc1e18db0400b8d4d6cc9ea2c982df34cd285"
},
"pipfile-spec": 6,
"requires": {
@@ -303,6 +303,15 @@
],
"version": "==2.1.0.post1"
},
"fastapi": {
"hashes": [
"sha256:8c77515984cd8e8cfeb58364f8cc7a28f0692088475e2614f7bf03275eba9093",
"sha256:b978095b9ee01a5cf49b19f4bc1ac9b8ca83aa076e770ef8fd9af09a2b88d191"
],
"index": "pypi",
"markers": "python_version >= '3.8'",
"version": "==0.109.0"
},
"frozenlist": {
"hashes": [
"sha256:04ced3e6a46b4cfffe20f9ae482818e34eba9b5fb0ce4056e4cc9b6e212d09b7",
@@ -652,7 +661,7 @@
"sha256:b3ef57c62535b0941697cce638c08900d87fcb67e29cfa99e8a68f747f393f7a",
"sha256:d0caf5954bee831b6bfe7e338c32b9e30c85dfe080c843680783ac2b631673b4"
],
"markers": "python_version >= '3.7'",
"markers": "python_version >= '3.12' and python_version < '4.0'",
"version": "==2.5.3"
},
"pydantic-core": {
@@ -886,6 +895,14 @@
"markers": "python_version >= '3.7'",
"version": "==1.3.0"
},
"starlette": {
"hashes": [
"sha256:3e2639dac3520e4f58734ed22553f950d3f3cb1001cd2eaac4d57e8cdc5f66bc",
"sha256:50bbbda9baa098e361f398fda0928062abbaf1f54f4fadcbe17c092a01eb9a25"
],
"markers": "python_version >= '3.8'",
"version": "==0.35.1"
},
"typer": {
"hashes": [
"sha256:50922fd79aea2f4751a8e0408ff10d2662bd0c8bbfa84755a699f3bada2978b2",
Binary file not shown.
+6
View File
@@ -5,3 +5,9 @@ class NeonClientException(HTTPError):
"""Base exception class for all exceptions raised by the Neon Client."""
pass
class NeonClientAuthenticationException(NeonClientException):
"""Exception raised when authentication fails."""
pass
+3 -4
View File
@@ -24,7 +24,7 @@ class Neon_API_V2:
response_model: BaseModel = None,
response_is_array=False,
check_status_code=True,
include_pagination=False,
_debug_pagination=False,
**kwargs,
):
"""
@@ -71,9 +71,8 @@ class Neon_API_V2:
else:
response_parsed = response_model(**r.json())
if include_pagination:
pagination = PaginationResponse(**r.json()).pagination
response_parsed.pagination = pagination
if _debug_pagination:
print(r.json())
return response_parsed
else:
File diff suppressed because it is too large Load Diff
+65 -38
View File
@@ -1,4 +1,5 @@
from typing import List
from fastapi.encoders import jsonable_encoder
from .http_client import Neon_API_V2
from .openapi_models import (
@@ -23,6 +24,11 @@ from .openapi_models import (
OperationResponse,
OperationsResponse,
PaginationResponse,
Branch,
Branch2,
Branch3,
BranchCreateRequestEndpointOptions,
BranchCreateRequest,
)
from .utils import validate_obj_model
@@ -33,6 +39,12 @@ class PagedOperationsResponse(OperationsResponse, PaginationResponse):
pass
class PagedProjectsResponse(ProjectsResponse, PaginationResponse):
"""A response containing a list of projects and pagination information."""
pass
class Resource:
base_path = None
@@ -44,15 +56,14 @@ class UserResource(Resource):
"""A resource for interacting with users."""
base_path = "users"
response_model = CurrentUserInfoResponse
def get_current_user_info(self):
"""Get information about the user."""
return self.api.request(
method="GET",
path=self.api.url_join(self.path, "me"),
response_model=self.response_model,
path=self.api.url_join(self.base_path, "me"),
response_model=CurrentUserInfoResponse,
)
@@ -95,43 +106,51 @@ class ProjectResource(Resource):
"""A resource for interacting with projects."""
base_path = "projects"
response_model = ProjectsResponse
response_model_single = ProjectResponse
def get_list(self, *, shared=False):
def get_list(
self,
*,
cursor: int | None = None,
limit: int | None = None,
shared: bool = False
):
"""Get a list of projects."""
project_response = self.api.request(
project_params = {}
if cursor is not None:
project_params["cursor"] = cursor
if limit is not None:
project_params["limit"] = limit
return self.api.request(
method="GET",
path=(
self.api.url_join(self.base_path, "shared")
if shared
else self.base_path
),
response_model=self.response_model,
params=project_params,
response_model=PagedProjectsResponse,
)
return project_response.projects
def get(self, project_id: str):
"""Get a project."""
project_response = self.api.request(
return self.api.request(
method="GET",
path=self.api.url_join(self.base_path, project_id),
response_model=self.response_model_single,
response_model=ProjectResponse,
)
return project_response.project
def create(self, name: str, **kwargs):
"""Create a new project."""
# def create(self, name: str, **kwargs):
# """Create a new project."""
project_create_response = self.api.request(
method="POST",
path=self.base_path,
json={"project": {"name": name, **kwargs}},
response_model=self.response_model_single,
)
return project_create_response.project
# return self.api.request(
# method="POST",
# path=self.base_path,
# json={"project": {"name": name, **kwargs}},
# response_model=ProjectResponse,
# ).model_dump()
def update(self, project: Project):
"""Update a project."""
@@ -142,7 +161,7 @@ class ProjectResource(Resource):
method="PATCH",
path=self.api.url_join(self.base_path, project.id),
json={"project": payload.model_dump()},
response_model=self.response_model_single,
response_model=ProjectResponse,
)
def delete(self, project_id: str):
@@ -151,14 +170,12 @@ class ProjectResource(Resource):
return self.api.request(
method="DELETE",
path=self.api.url_join(self.base_path, project_id),
response_model=self.response_model_single,
response_model=ProjectResponse,
)
class DatabaseResource(Resource):
base_path = "databases"
response_model = DatabasesResponse
response_model_single = DatabaseResponse
def _extract_database(self, obj):
"""Extract a database from the specified object."""
@@ -189,7 +206,7 @@ class DatabaseResource(Resource):
"projects", project_id, "branches", branch_id, "databases"
),
response_model=DatabasesResponse,
).model_dump()
)
def get(
self,
@@ -210,7 +227,7 @@ class DatabaseResource(Resource):
database_name,
),
response_model=DatabaseResponse,
).model_dump()
)
def create(
self,
@@ -227,9 +244,9 @@ class DatabaseResource(Resource):
path=self.api.url_join(
"projects", project_id, "branches", branch_id, "databases"
),
json=DatabaseCreateRequest(database=db).model_dump(),
json=jsonable_encoder(DatabaseCreateRequest(database=db)),
response_model=DatabaseResponse,
).model_dump()
)
def update(
self,
@@ -249,15 +266,13 @@ class DatabaseResource(Resource):
),
json=DatabaseUpdateRequest(database=db).model_dump(),
response_model=DatabaseResponse,
).model_dump()
)
class BranchResource(Resource):
"""A resource for interacting with branches."""
path = "branches"
response_model = BranchesResponse
response_model_single = BranchResponse
def get_list(self, project_id: str):
"""Get a list of branches."""
@@ -265,8 +280,8 @@ class BranchResource(Resource):
return self.api.request(
method="GET",
path=self.api.url_join("projects", project_id, "branches"),
response_model=self.response_model,
).model_dump()
response_model=BranchesResponse,
)
def get(self, project_id: str, branch_id: str):
"""Get a branch."""
@@ -274,8 +289,18 @@ class BranchResource(Resource):
return self.api.request(
method="GET",
path=self.api.url_join("projects", project_id, "branches", branch_id),
response_model=self.response_model_single,
).model_dump()
response_model=BranchResponse,
)
def create(self, project_id: str, request: BranchCreateRequest):
"""Create a new branch."""
return self.api.request(
method="POST",
path=self.api.url_join("projects", project_id, "branches"),
json=request.model_dump(),
response_model=BranchResponse,
)
class OperationResource(Resource):
@@ -302,7 +327,7 @@ class OperationResource(Resource):
path=self.api.url_join("projects", project_id, "operations"),
params=operations_params,
response_model=PagedOperationsResponse,
).model_dump()
)
def get(self, project_id: str, operation_id: str):
"""Get an operation."""
@@ -311,13 +336,15 @@ class OperationResource(Resource):
method="GET",
path=self.api.url_join("projects", project_id, "operations", operation_id),
response_model=OperationResponse,
).model_dump()
)
class ResourceCollection:
"""A collection of resources."""
def __init__(self, api: Neon_API_V2):
"""Initialize the collection."""
# Initialize resources.
self.api_keys = APIKeyResource(api)
self.users = UserResource(api)
+81 -4
View File
@@ -1,13 +1,90 @@
from .http_client import Neon_API_V2
from .resources import ResourceCollection
from . import openapi_models
class BaseNeonItem:
def __repr__(self):
return str(self)
class NeonUser(BaseNeonItem, openapi_models.CurrentUserAuthAccount):
@classmethod
def from_get_response(cls, r):
"""Create a NeonUser from an API response."""
# TODO: is this the right way to do this?
me = r.auth_accounts[0]
return cls.model_validate(me.model_dump())
def __str__(self):
return f"<NeonUser email={self.email}>"
class CollectionView:
def __init__(self, collection, key_ids=None):
if not key_ids:
key_ids = []
self._key_ids = key_ids
self._collection = collection
def __iter__(self):
return iter(self._collection)
def __getitem__(self, key):
for k in key_ids:
for item in self._collection:
if getattr(item, k) == key:
return item
return self._collection[key]
def __len__(self):
return len(self._collection)
def __repr__(self):
return repr(self._collection)
class NeonAPIKey(BaseNeonItem, openapi_models.ApiKeysListResponseItem):
@classmethod
def from_list_response(cls, r, *, neon):
"""Create a list of APIKeys from an API response."""
def gen():
for key in r:
k = cls.model_validate(key.model_dump())
k.neon = neon
yield k
return [g for g in gen()]
def __str__(self):
return f"<NeonAPIKey id={self.id}>"
def revoke(self, *, neon):
"""Revoke this API key."""
return bool(neon.resources.api_keys.revoke(self.id))
class NeonClient:
def __init__(self, api_key: str, **kwargs):
self.api = Neon_API_V2(api_key, **kwargs)
self.resources = ResourceCollection(self.api)
# self.api_keys = APIKeyResource(self.api)
# self.users = UserResource(self.api)
# self.projects = ProjectResource(self.api)
# self.databases = DatabaseResource(self.api)
@property
def me(self):
user = self.resources.users.get_current_user_info()
return NeonUser.from_get_response(user)
@property
def api_keys(self):
keys = self.resources.api_keys.get_list()
return CollectionView(
NeonAPIKey.from_list_response(keys, neon=self), key_ids=["id"]
)