119 lines
4.6 KiB
Python
119 lines
4.6 KiB
Python
import json
|
|
import inspect
|
|
import asyncio
|
|
from typing import Optional, Union, Iterable, Mapping
|
|
|
|
from .core import (
|
|
JSONRPC20Request, JSONRPC20BatchRequest, JSONRPC20Response,
|
|
JSONRPC20BatchResponse, JSONRPC20MethodNotFound, JSONRPC20InvalidParams,
|
|
JSONRPC20ServerError, JSONRPC20ParseError, JSONRPC20InvalidRequest,
|
|
JSONRPC20DispatchException,
|
|
)
|
|
from .dispatcher import Dispatcher
|
|
from .utils import is_invalid_params
|
|
|
|
|
|
class AsyncJSONRPCResponseManager:
    """Async JSON-RPC 2.0 response manager.

    Deserializes a raw payload, looks up each request's method in the
    dispatcher, invokes it (awaiting the result when it is awaitable) and
    builds the corresponding JSON-RPC response objects. Requests of a batch
    are processed concurrently with ``asyncio.gather``.
    """

    def __init__(self, dispatcher: Dispatcher, serialize=json.dumps,
                 deserialize=json.loads, is_server_error_verbose=False):
        """
        :param dispatcher: lookup object; ``dispatcher[name]`` must return the
            callable for method ``name`` and raise ``KeyError`` if unknown.
        :param serialize: callable turning a response ``body`` into the
            outgoing payload string (default ``json.dumps``).
        :param deserialize: callable turning an incoming payload string into
            Python data (default ``json.loads``).
        :param is_server_error_verbose: when true, server-error responses
            include the exception type, args and message in ``data``.
        """
        self.dispatcher = dispatcher
        self.serialize = serialize
        self.deserialize = deserialize
        self.is_server_error_verbose = is_server_error_verbose

    async def get_response_for_request(self, request: JSONRPC20Request) -> Optional[JSONRPC20Response]:
        """Dispatch a single request and build its response.

        :returns: ``None`` for notifications; otherwise a response holding
            either the method's result or an error:
            MethodNotFound (unknown method), the error carried by a
            ``JSONRPC20DispatchException``, InvalidParams (when
            ``is_invalid_params`` reports a signature mismatch) or ServerError
            (any other exception).
        """
        # Only touch request.id for real requests; notifications get no id.
        response_id = None if request.is_notification else request.id

        try:
            method = self.dispatcher[request.method]
        except KeyError:
            output = JSONRPC20Response(error=JSONRPC20MethodNotFound(), id=response_id)
        else:
            try:
                result = method(*request.args, **request.kwargs)
                # Await coroutine functions as well as plain callables that
                # return an awaitable (e.g. functools.partial of an async def);
                # checking only iscoroutinefunction(method) misses the latter.
                if inspect.isawaitable(result):
                    result = await result
            except JSONRPC20DispatchException as dispatch_error:
                # The method raised an error with a controlled JSON-RPC payload.
                output = JSONRPC20Response(error=dispatch_error.error, id=response_id)
            except Exception as e:
                # Broad catch is intentional: this is the RPC boundary, and any
                # failure must become a JSON-RPC error rather than a crash.
                if is_invalid_params(method, *request.args, **request.kwargs):
                    output = JSONRPC20Response(error=JSONRPC20InvalidParams(), id=response_id)
                else:
                    error_data = {
                        "type": e.__class__.__name__,
                        "args": e.args,
                        "message": str(e),
                    } if self.is_server_error_verbose else None
                    output = JSONRPC20Response(
                        error=JSONRPC20ServerError(data=error_data),
                        id=response_id,
                    )
            else:
                output = JSONRPC20Response(result=result, id=response_id)

        # Notifications are executed but never answered.
        return None if request.is_notification else output

    async def get_response_for_request_body(self, request_body) -> Optional[JSONRPC20Response]:
        """Build a request from one deserialized body and dispatch it.

        A ``ValueError`` from ``JSONRPC20Request.from_body`` is mapped to an
        InvalidRequest error response without an id; other exceptions from
        ``from_body`` are not caught here.
        """
        try:
            request = JSONRPC20Request.from_body(request_body)
        except ValueError:
            return JSONRPC20Response(error=JSONRPC20InvalidRequest())
        return await self.get_response_for_request(request)

    @staticmethod
    def _is_batch(request_data) -> bool:
        """Return True when deserialized data represents a batch (an array).

        Strings and bytes are iterable but are never batches; treating them as
        such would split e.g. ``"abc"`` into one bogus request per character.
        """
        return isinstance(request_data, Iterable) and not isinstance(
            request_data, (Mapping, str, bytes, bytearray)
        )

    async def get_response_for_payload(self, payload: str) -> Optional[Union[JSONRPC20Response, JSONRPC20BatchResponse]]:
        """Top-level handler: turn a raw payload string into a response object.

        - Payload that ``deserialize`` rejects (TypeError/ValueError): a single
          ParseError response.
        - Empty batch: a single InvalidRequest response.
        - Batch: a ``JSONRPC20BatchResponse`` of the non-notification
          responses, or ``None`` if every entry was a notification.
        - Single request: its response, or ``None`` for a notification.
        """
        try:
            request_data = self.deserialize(payload)
        except (TypeError, ValueError):
            return JSONRPC20Response(error=JSONRPC20ParseError())

        is_batch_request = self._is_batch(request_data)
        # Materialize once so emptiness checks also work for unsized iterables
        # a custom deserializer might return.
        request_bodies = list(request_data) if is_batch_request else [request_data]
        if is_batch_request and not request_bodies:
            return JSONRPC20Response(error=JSONRPC20InvalidRequest())

        responses = await asyncio.gather(*(
            self.get_response_for_request_body(body) for body in request_bodies
        ))
        nonempty_responses = [r for r in responses if r is not None]

        if not nonempty_responses:
            return None
        if is_batch_request:
            return JSONRPC20BatchResponse(nonempty_responses)
        return nonempty_responses[0]

    async def get_payload_for_payload(self, payload: str) -> str:
        """Handle a raw payload and return the serialized response.

        Returns an empty string when there is nothing to answer (notification
        or batch consisting only of notifications).
        """
        response = await self.get_response_for_payload(payload)
        if response is None:
            return ""
        return self.serialize(response.body)
|