From 735abfaf598598349d627ecb74f6090c8d7cfcad Mon Sep 17 00:00:00 2001 From: Stephan Lee Date: Wed, 20 Oct 2021 10:38:41 -0700 Subject: [PATCH] projector: add notebook renderer This change introduces Python API to render the new projector in Juypter notebook. It currently provides work in progress API, `embedding` which visualizes the embedding data and label rendered in the visualization. In order to support both Colab and Jupyter, I have decided to introduce an abstraction, `renderer` which knows how to render an output cell and how to communicate with it. Do note that Jupyter and Colab are very different; Jupyter does not encapsulate JavaScript context for each output cell so a symbol exposed on `globalThis` bleeds over to other output cells while Colab does not do that. --- .../visualization/projector_v2/notebook.py | 98 ++++++++++++ .../visualization/projector_v2/renderer.py | 143 ++++++++++++++++++ 2 files changed, 241 insertions(+) create mode 100644 tensorflow_similarity/visualization/projector_v2/notebook.py create mode 100644 tensorflow_similarity/visualization/projector_v2/renderer.py diff --git a/tensorflow_similarity/visualization/projector_v2/notebook.py b/tensorflow_similarity/visualization/projector_v2/notebook.py new file mode 100644 index 00000000..7fb70f30 --- /dev/null +++ b/tensorflow_similarity/visualization/projector_v2/notebook.py @@ -0,0 +1,98 @@ +# Copyright 2021 The TensorFlow Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import random +from typing import Optional, Sequence, Tuple, Union + +import numpy as np +from tensorflow_similarity.visualization.projector_v2 import renderer + + +def _get_renderer(): + # In Colab, the `google.colab` module is available, but the shell + # returned by `IPython.get_ipython` does not have a `get_trait` + # method. + try: + import google.colab # noqa: F401 + import IPython + except ImportError: + pass + else: + if IPython.get_ipython() is not None: + # We'll assume that we're in a Colab notebook context. + raise NotImplementedError("Colab support not implemented") + + # In an IPython command line shell or Jupyter notebook, we can + # directly query whether we're in a notebook c ontext. + try: + import IPython + except ImportError: + pass + else: + ipython = IPython.get_ipython() + if ipython is not None and ipython.has_trait("kernel"): + return renderer.IPythonRenderer() + + # Otherwise, we're not in a known notebook context. + raise NotImplementedError("Must use the tool under a notebook context.") + + +def embedding( + embeddings: Sequence[Union[Tuple[float, float], Tuple[float, float, float]]], + labels: Optional[Sequence[Union[str, int]]] = None, + image_labels: Optional[Sequence[Union[str, int]]] = None, +): + """ """ + if isinstance(embeddings, np.ndarray): + embeddings = embeddings.tolist() + + cur_renderer = _get_renderer() + handle = cur_renderer.display() + cur_renderer.send_message( + handle, + "update", + embeddings, + {"step": 0, "maxStep": 0}, + ) + + if labels: + label_to_color = dict() + metadata = [] + for label, image in itertools.zip_longest( + labels, image_labels or [], fillvalue="" + ): + color = label_to_color.get(label, None) + if color is None: + red = random.randrange(0, 0xFF) + green = random.randrange(0, 0xFF) + blue = random.randrange(0, 0xFF) + color = red * 2 ** 16 + green * 2 ** 8 + blue + label_to_color.setdefault(label, color) + + print(label, color) + metadata.append( + { + "label": label, + "color": color, + "imageLabel": image, + } + ) + + cur_renderer.send_message( + handle, + "meta", + metadata, + {"algo": "custom"}, + ) diff --git a/tensorflow_similarity/visualization/projector_v2/renderer.py b/tensorflow_similarity/visualization/projector_v2/renderer.py new file mode 100644 index 00000000..fc99ffa4 --- /dev/null +++ b/tensorflow_similarity/visualization/projector_v2/renderer.py @@ -0,0 +1,143 @@ +# Copyright 2021 The TensorFlow Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import functools +import json +import random +import string +import threading +from http import server +from os import path +from typing import Any, Dict, Optional, Sequence, Union + + +class Renderer(metaclass=abc.ABCMeta): + @abc.abstractmethod + def display(self): + pass + + @abc.abstractmethod + def send_message( + self, + cell_id: str, + msg_type: str, + payload: Sequence[Any], + other_payload: Optional[Dict[str, Union[str, int, float]]], + ): + pass + + +class IPythonRenderer(Renderer): + class _IPythonRequestHandler(server.BaseHTTPRequestHandler): + def __init__(self, *args, callback=None, **kwargs): + self._cb = callback + super(IPythonRenderer._IPythonRequestHandler, self).__init__( + *args, **kwargs + ) + + def do_OPTIONS(self): + self.send_response(200) + self.send_header("Access-Control-Allow-Origin", "*") + self.send_header("Access-Control-Request-Method", "POST") + self.send_header("Access-Control-Allow-Headers", "Content-Type") + self.end_headers() + self.wfile.write("OK".encode("utf8")) + + def do_POST(self): + if self._cb: + content_length = int(self.headers["Content-Length"]) + form_content = self.rfile.read(content_length) + payload = json.loads(form_content) + + self._cb(payload) + + self.send_response(200) + self.send_header("Content-type", "text/html") + self.send_header("Access-Control-Allow-Origin", "*") + self.end_headers() + self.wfile.write("OK".encode("utf8")) + + def __init__(self): + self._server = server.ThreadingHTTPServer( + ("", 0), + functools.partial( + IPythonRenderer._IPythonRequestHandler, callback=self._on_req + ), + ) + thread = threading.Thread( + target=self._server.serve_forever, + ) + thread.daemon = True + thread.start() + + def _generate_id(self): + return "".join(random.choices(string.ascii_letters, k=12)) + + def display(self): + from IPython import display + + display.clear_output() + with open(path.join(path.dirname(__file__), "bin", "index.js"), "r") as f: + library = display.Javascript(f.read()) + + unique_id = self._generate_id() + container = display.HTML( + f'
' + ) + bootstrap = display.Javascript( + """ + globalThis.messenger.initForIPython(%d); + globalThis.messenger.createMessengerForOutputcell("%s"); + globalThis.bootstrap("%s"); + """ + % (self._server.server_port, unique_id, unique_id) + ) + display.display( + library, + container, + ) + display.display( + bootstrap, + ) + + return unique_id + + def send_message( + self, + cell_id: str, + msg_type: str, + payload: Sequence[Any], + other_payload: Optional[Dict[str, Union[str, int, float]]] = {}, + ): + from IPython import display + + unique_id = self._generate_id() + send_payload = display.Javascript( + f""" + globalThis["{unique_id}"] = {json.dumps(payload)}; + """ + ) + load_payload = display.Javascript( + f""" + globalThis.messenger.onMessageFromPython( + "{cell_id}", + "{msg_type}", + ["{unique_id}"], + {json.dumps(other_payload)} + ); + """ + ) + display.display(send_payload) + display.display(load_payload)