Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 29 additions & 18 deletions pyk/src/pyk/rpc/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Iterator
from dataclasses import dataclass
from functools import partial
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import TYPE_CHECKING, Any, Final, NamedTuple
from typing import TYPE_CHECKING, NamedTuple

from typing_extensions import Protocol

Expand All @@ -15,6 +16,7 @@
if TYPE_CHECKING:
from collections.abc import Callable
from pathlib import Path
from typing import Any, Final


_LOGGER: Final = logging.getLogger(__name__)
Expand Down Expand Up @@ -86,7 +88,7 @@ class JsonRpcBatchRequest(NamedTuple):
class JsonRpcResult(ABC):

@abstractmethod
def encode(self) -> bytes: ...
def encode(self) -> Iterator[bytes]: ...


@dataclass(frozen=True)
Expand All @@ -96,7 +98,7 @@ class JsonRpcError(JsonRpcResult):
message: str
id: str | int | None

def to_json(self) -> dict[str, Any]:
def wrap_response(self) -> dict[str, Any]:
return {
'jsonrpc': JsonRpcServer.JSONRPC_VERSION,
'error': {
Expand All @@ -106,32 +108,39 @@ def to_json(self) -> dict[str, Any]:
'id': self.id,
}

def encode(self) -> bytes:
return json.dumps(self.to_json()).encode('ascii')
def encode(self) -> Iterator[bytes]:
yield json.dumps(self.wrap_response()).encode('ascii')


@dataclass(frozen=True)
class JsonRpcSuccess(JsonRpcResult):
payload: Any
id: Any

def to_json(self) -> dict[str, Any]:
return {
'jsonrpc': JsonRpcServer.JSONRPC_VERSION,
'result': self.payload,
'id': self.id,
}

def encode(self) -> bytes:
return json.dumps(self.to_json()).encode('ascii')
def encode(self) -> Iterator[bytes]:
yield f'{{"jsonrpc":"2.0", "id": {self.id}, "result": '.encode('ascii')
if isinstance(self.payload, Iterator):
for chunk in self.payload:
yield chunk.encode('ascii')
else:
yield json.dumps(self.payload).encode('ascii')
yield b'}'


@dataclass(frozen=True)
class JsonRpcBatchResult(JsonRpcResult):
results: tuple[JsonRpcError | JsonRpcSuccess, ...]

def encode(self) -> bytes:
return json.dumps([result.to_json() for result in self.results]).encode('ascii')
def encode(self) -> Iterator[bytes]:
yield b'['
first = True
for result in self.results:
if not first:
yield b','
else:
first = False
yield from result.encode()
yield b']'


class JsonRpcRequestHandler(BaseHTTPRequestHandler):
Expand All @@ -143,8 +152,10 @@ def __init__(self, methods: dict[str, JsonRpcMethod], *args: Any, **kwargs: Any)

def _send_response(self, response: JsonRpcResult) -> None:
self.send_response_headers()
response_bytes = response.encode()
self.wfile.write(response_bytes)
response_body = response.encode()
for chunk in response_body:
self.wfile.write(chunk)
self.wfile.flush()

def send_response_headers(self) -> None:
self.send_response(200)
Expand Down
10 changes: 10 additions & 0 deletions pyk/src/tests/integration/test_json_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pyk.testing import KRunTest

if TYPE_CHECKING:
from collections.abc import Iterator
from typing import Any


Expand Down Expand Up @@ -154,6 +155,7 @@ def __init__(self, options: ServeRpcOptions) -> None:
self.register_method('set_x', self.exec_set_x)
self.register_method('set_y', self.exec_set_y)
self.register_method('add', self.exec_add)
self.register_method('streaming', self.exec_streaming)

def exec_get_x(self) -> int:
return self.x
Expand All @@ -170,6 +172,11 @@ def exec_set_y(self, n: int) -> None:
def exec_add(self) -> int:
return self.x + self.y

def exec_streaming(self) -> Iterator[bytes]:
yield b'{'
yield b'"foo": "bar"'
yield b'}'


class TestJsonRPCServer(KRunTest):

Expand Down Expand Up @@ -221,6 +228,9 @@ def wait_until_ready() -> None:
assert len(res) == 3
assert res[2]['result'] == 1 + 2

res = rpc_client.request('streaming', [])
assert res == {'foo': 'bar'}

server.shutdown()
thread.join()

Expand Down
Loading