Spaces:
Running
Running
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Example of invisibly propagating a request ID with middleware."""
import argparse | |
import sys | |
import threading | |
import uuid | |
import pyarrow as pa | |
import pyarrow.flight as flight | |
class TraceContext:
    """Per-thread holder for the current trace (request) ID.

    The ID is kept in a ``threading.local``, so each thread sees its own
    value. Both accessors are class methods and are called on the class
    itself, e.g. ``TraceContext.current_trace_id()``.
    """

    # Thread-local storage: attributes set on it are visible only to the
    # thread that set them. This initial assignment therefore only affects
    # the thread that executes the class body; other threads rely on the
    # getattr(..., None) fallback in current_trace_id.
    _locals = threading.local()
    _locals.trace_id = None

    @classmethod
    def current_trace_id(cls):
        """Return this thread's trace ID, generating one if none is set.

        A missing or falsy ID is replaced by a new random hex UUID, which is
        stored so later calls on the same thread return the same value.
        """
        if not getattr(cls._locals, "trace_id", None):
            cls.set_trace_id(uuid.uuid4().hex)
        return cls._locals.trace_id

    @classmethod
    def set_trace_id(cls, trace_id):
        """Set this thread's trace ID; a falsy value (e.g. None) clears it."""
        cls._locals.trace_id = trace_id
# Header name used to carry the trace ID in both directions: clients send
# it with each call and the server echoes it back in its response headers.
TRACE_HEADER = "x-tracing-id"
class TracingServerMiddleware(flight.ServerMiddleware):
    """Per-call server middleware that echoes the call's trace ID.

    One instance is returned by the factory's ``start_call`` for each call;
    it remembers the trace ID and attaches it to the outgoing headers.
    """

    def __init__(self, trace_id):
        # Also read by FlightServer.do_action via
        # context.get_middleware("trace").
        self.trace_id = trace_id

    def sending_headers(self):
        """Return the headers to send back: only the trace ID."""
        outgoing = {TRACE_HEADER: self.trace_id}
        return outgoing
class TracingServerMiddlewareFactory(flight.ServerMiddlewareFactory):
    """Create a TracingServerMiddleware for every incoming call.

    If the client sent a trace header, its first value becomes the call's
    trace ID; otherwise a fresh ID is generated.
    """

    def start_call(self, info, headers):
        """Establish the trace ID for a new call and return its middleware.

        :param info: call information, only logged here.
        :param headers: incoming request headers; ``headers[TRACE_HEADER][0]``
            is taken as the first value for the trace header.
        :returns: a TracingServerMiddleware holding the call's trace ID.
        """
        print("Starting new call:", info)
        if TRACE_HEADER in headers:
            trace_id = headers[TRACE_HEADER][0]
            print("Found trace header with value:", trace_id)
        else:
            # The trace context is thread-local, and a server thread may be
            # reused across calls. Clear any ID left over from an earlier
            # call so current_trace_id() generates a fresh one instead of
            # tagging this request with a stale ID.
            trace_id = None
        TraceContext.set_trace_id(trace_id)
        return TracingServerMiddleware(TraceContext.current_trace_id())
class TracingClientMiddleware(flight.ClientMiddleware):
    """Per-call client middleware that propagates the trace ID.

    Outgoing calls carry the calling thread's current trace ID. A trace ID
    echoed back by the server is only logged; it never replaces the local one.
    """

    def sending_headers(self):
        """Return request headers carrying the current trace ID."""
        # Look the ID up once so the logged value and the sent value come
        # from the same lookup.
        trace_id = TraceContext.current_trace_id()
        print("Sending trace ID:", trace_id)
        return {
            TRACE_HEADER: trace_id,
        }

    def received_headers(self, headers):
        """Log the trace ID returned by the server, if present."""
        if TRACE_HEADER in headers:
            trace_id = headers[TRACE_HEADER][0]
            print("Found trace header with value:", trace_id)
            # Don't overwrite our trace ID
class TracingClientMiddlewareFactory(flight.ClientMiddlewareFactory):
    """Attach a new TracingClientMiddleware to every outgoing call."""

    def start_call(self, info):
        """Log the call being started and return its tracing middleware."""
        print("Starting new call:", info)
        middleware = TracingClientMiddleware()
        return middleware
class FlightServer(flight.FlightServerBase):
    """Flight server exposing a single ``get-trace-id`` action.

    When a delegate location is given, the action is forwarded to that
    server through a client that carries the tracing middleware, which
    demonstrates propagating the trace ID between servers.
    """

    def __init__(self, delegate, **kwargs):
        """Create the server.

        :param delegate: location of another Flight server to forward
            actions to, or a falsy value to answer actions locally.
        :param kwargs: passed through unchanged to ``FlightServerBase``.
        """
        super().__init__(**kwargs)
        if delegate:
            self.delegate = flight.connect(
                delegate,
                middleware=(TracingClientMiddlewareFactory(),))
        else:
            self.delegate = None

    def list_actions(self, context):
        """Advertise the supported actions as (type, description) pairs."""
        return [
            ("get-trace-id", "Get the trace context ID."),
        ]

    def do_action(self, context, action):
        """Yield the trace ID for ``get-trace-id``, or forward the action.

        :raises KeyError: for any action type other than ``get-trace-id``
            (raised when the generator is first advanced).
        """
        # The middleware registered under "trace" holds this call's trace
        # ID; copy it into the thread-local context used by this handler.
        trace_middleware = context.get_middleware("trace")
        if trace_middleware:
            TraceContext.set_trace_id(trace_middleware.trace_id)

        if action.type != "get-trace-id":
            raise KeyError(f"Unknown action {action.type!r}")

        if self.delegate:
            # The delegate client's tracing middleware reads TraceContext
            # when sending headers, so the ID set above is propagated.
            yield from self.delegate.do_action(action)
        else:
            trace_id = TraceContext.current_trace_id().encode("utf-8")
            print("Returning trace ID:", trace_id)
            yield flight.Result(pa.py_buffer(trace_id))
def main():
    """Parse the command line and run either the Flight client or server.

    :returns: 1 (after printing help) when no subcommand is given,
        otherwise None.
    """
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="command")

    # Keep the parser objects distinct from the client/server objects
    # created below, rather than rebinding the same names.
    client_parser = subparsers.add_parser("client", help="Run the client.")
    client_parser.add_argument("server")
    client_parser.add_argument("--request-id", default=None)

    server_parser = subparsers.add_parser("server", help="Run the server.")
    server_parser.add_argument(
        "--listen",
        required=True,
        help="The location to listen on (example: grpc://localhost:5050)",
    )
    server_parser.add_argument(
        "--delegate",
        required=False,
        default=None,
        help=("A location to delegate to. That is, this server will "
              "simply call the given server for the response. Demonstrates "
              "propagation of the trace ID between servers."),
    )

    args = parser.parse_args()
    # With dest="command", args.command is None when no subcommand is given.
    if not args.command:
        parser.print_help()
        return 1

    if args.command == "server":
        server = FlightServer(
            args.delegate,
            location=args.listen,
            middleware={"trace": TracingServerMiddlewareFactory()})
        server.serve()
    elif args.command == "client":
        client = flight.connect(
            args.server,
            middleware=(TracingClientMiddlewareFactory(),))
        # An empty or missing --request-id falls back to a fixed ID.
        TraceContext.set_trace_id(args.request_id or "client-chosen-id")
        for result in client.do_action(flight.Action("get-trace-id", b"")):
            print(result.body.to_pybytes())
if __name__ == "__main__":
    # main() returns 1 when no subcommand is given and None otherwise;
    # "or 0" turns None into an explicit zero exit status.
    sys.exit(main() or 0)