Skip to content

Commit 5f94077

Browse files
committed
Fix broken tracing functionality
1 parent 77eb9e9 commit 5f94077

4 files changed

Lines changed: 212 additions & 21 deletions

File tree

src/workflows/services/common_service.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -197,17 +197,7 @@ def start_transport(self):
197197
otel_config = (
198198
self.config._opentelemetry if self.config and hasattr(self.config, "opentelemetry") else None
199199
)
200-
# debugging
201-
with open("/scratch/logs.txt", 'w+') as file:
202-
if otel_config:
203-
import json
204-
json.dump(otel_config, file, indent=4)
205-
else:
206-
file.write("otel config was not truthy")
207-
208-
if otel_config and "timeout" not in otel_config:
209-
self.log.warning("Missing optional OTEL configuration field `timeout`. Will default to 10 seconds. ")
210-
200+
if otel_config:
211201
# Configure OTELTracing
212202
resource = Resource.create(
213203
{

src/workflows/transport/common_transport.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
MessageCallback = Callable[[Mapping[str, Any], Any], None]
1212

13-
1413
class TemporarySubscription(NamedTuple):
1514
subscription_id: int
1615
queue_name: str

src/workflows/transport/middleware/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,10 @@ def wrapped_callback(header, message):
233233

234234

235235
def wrap(f: Callable):
236+
# debugging
237+
if f.__name__ == "send":
238+
print("we are wrapping send now")
239+
236240
@functools.wraps(f)
237241
def wrapper(self, *args, **kwargs):
238242
return functools.reduce(
@@ -243,4 +247,5 @@ def wrapper(self, *args, **kwargs):
243247
lambda *args, **kwargs: f(self, *args, **kwargs),
244248
)(*args, **kwargs)
245249

246-
return wrapper
250+
print(wrapper.__wrapped__)
251+
return wrapper

src/workflows/transport/middleware/otel_tracing.py

Lines changed: 205 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,228 @@
44
from collections.abc import Callable
55

66
from opentelemetry import trace
7-
from opentelemetry.propagate import extract
7+
from opentelemetry.propagate import extract, inject
8+
from opentelemetry.context import Context
89

910
from workflows.transport.middleware import BaseTransportMiddleware
11+
from workflows.transport.common_transport import TemporarySubscription, MessageCallback
12+
import json
1013

11-
12-
class OTELTracingMiddleware(BaseTransportMiddleware):
14+
class OTELTracingMiddleware:
1315
def __init__(self, tracer: trace.Tracer, service_name: str):
1416
self.tracer = tracer
1517
self.service_name = service_name
1618

17-
def subscribe(self, call_next: Callable, channel, callback, **kwargs) -> int:
19+
def send(self, call_next: Callable, destination: str, message: Any, **kwargs):
20+
# Get current span context (may be None if this is the root span)
21+
current_span = trace.get_current_span()
22+
parent_context = trace.set_span_in_context(current_span) if current_span else None
23+
24+
with self.tracer.start_as_current_span(
25+
"transport.send",
26+
context=parent_context,
27+
) as span:
28+
span.set_attribute("service_name", self.service_name)
29+
30+
span.set_attribute("message", json.dumps(message))
31+
span.set_attribute("destination", destination)
32+
print("parent_context is...",parent_context)
33+
34+
35+
# Inject the current trace context into the message headers
36+
headers = kwargs.get("headers", {})
37+
if headers is None:
38+
headers = {}
39+
inject(headers) # This modifies headers in-place
40+
kwargs["headers"] = headers
41+
42+
return call_next(destination, message, **kwargs)
43+
44+
def subscribe(self, call_next: Callable, channel: str, callback: Callable, **kwargs) -> int:
1845
@functools.wraps(callback)
1946
def wrapped_callback(header, message):
2047
# Extract trace context from message headers
21-
ctx = extract(header) if header else None
48+
ctx = extract(header) if header else Context()
2249

2350
# Start a new span with the extracted context
2451
with self.tracer.start_as_current_span(
25-
"transport.subscribe", context=ctx
52+
"transport.subscribe",
53+
context=ctx,
54+
) as span:
55+
span.set_attribute("service_name", self.service_name)
56+
57+
span.set_attribute("message", json.dumps(message))
58+
span.set_attribute("channel", channel)
59+
60+
# Call the original callback - this will process the message
61+
# and potentially call send() which will pick up this context
62+
return callback(header, message)
63+
64+
return call_next(channel, wrapped_callback, **kwargs)
65+
66+
def subscribe_broadcast(self, call_next: Callable, channel: str, callback: Callable, **kwargs) -> int:
67+
@functools.wraps(callback)
68+
def wrapped_callback(header, message):
69+
# Extract trace context from message headers
70+
ctx = extract(header) if header else Context()
71+
72+
# # Start a new span with the extracted context
73+
with self.tracer.start_as_current_span(
74+
"transport.subscribe_broadcast",
75+
context=ctx,
2676
) as span:
2777
span.set_attribute("service_name", self.service_name)
78+
79+
span.set_attribute("message", json.dumps(message))
2880
span.set_attribute("channel", channel)
2981

30-
# Call the original callback
3182
return callback(header, message)
3283

33-
# Call the next middleware with the wrapped callback
3484
return call_next(channel, wrapped_callback, **kwargs)
85+
86+
def subscribe_temporary(
87+
self,
88+
call_next: Callable,
89+
channel_hint: str | None,
90+
callback: MessageCallback,
91+
**kwargs,
92+
) -> TemporarySubscription:
93+
@functools.wraps(callback)
94+
def wrapped_callback(header, message):
95+
# Extract trace context from message headers
96+
ctx = extract(header) if header else Context()
97+
98+
# Start a new span with the extracted context
99+
with self.tracer.start_as_current_span(
100+
"transport.subscribe_temporary",
101+
context=ctx,
102+
) as span:
103+
span.set_attribute("service_name", self.service_name)
104+
105+
span.set_attribute("message", json.dumps(message))
106+
if channel_hint:
107+
span.set_attribute("channel_hint", channel_hint)
108+
109+
return callback(header, message)
110+
111+
return call_next(channel_hint, wrapped_callback, **kwargs)
112+
113+
def unsubscribe(
114+
self,
115+
call_next: Callable,
116+
subscription: int,
117+
drop_callback_reference=False,
118+
**kwargs,
119+
):
120+
# Get current span context
121+
current_span = trace.get_current_span()
122+
current_context = trace.set_span_in_context(current_span) if current_span else Context()
123+
124+
with self.tracer.start_as_current_span(
125+
"transport.unsubscribe",
126+
context=current_context,
127+
) as span:
128+
span.set_attribute("service_name", self.service_name)
129+
span.set_attribute("subscription_id", subscription)
130+
131+
call_next(
132+
subscription, drop_callback_reference=drop_callback_reference, **kwargs
133+
)
134+
135+
def ack(
136+
self,
137+
call_next: Callable,
138+
message,
139+
subscription_id: int | None = None,
140+
**kwargs,
141+
):
142+
# Get current span context
143+
current_span = trace.get_current_span()
144+
current_context = trace.set_span_in_context(current_span) if current_span else Context()
145+
146+
with self.tracer.start_as_current_span(
147+
"transport.ack",
148+
context=current_context,
149+
) as span:
150+
span.set_attribute("service_name", self.service_name)
151+
span.set_attribute("message", json.dumps(message))
152+
if subscription_id:
153+
span.set_attribute("subscription_id", subscription_id)
154+
155+
call_next(message, subscription_id=subscription_id, **kwargs)
156+
157+
def nack(
158+
self,
159+
call_next: Callable,
160+
message,
161+
subscription_id: int | None = None,
162+
**kwargs,
163+
):
164+
# Get current span context
165+
current_span = trace.get_current_span()
166+
current_context = trace.set_span_in_context(current_span) if current_span else Context()
167+
168+
with self.tracer.start_as_current_span(
169+
"transport.nack",
170+
context=current_context,
171+
) as span:
172+
span.set_attribute("service_name", self.service_name)
173+
174+
span.set_attribute("message", json.dumps(message))
175+
if subscription_id:
176+
span.set_attribute("subscription_id", subscription_id)
177+
178+
call_next(message, subscription_id=subscription_id, **kwargs)
179+
180+
def transaction_begin(
181+
self, call_next: Callable, subscription_id: int | None = None, **kwargs
182+
) -> int:
183+
"""Start a new transaction span"""
184+
# Get current span context (may be None if this is the root span)
185+
current_span = trace.get_current_span()
186+
current_context = trace.set_span_in_context(current_span) if current_span else Context()
187+
188+
with self.tracer.start_as_current_span(
189+
"transaction.begin",
190+
context=current_context,
191+
) as span:
192+
span.set_attribute("service_name", self.service_name)
193+
194+
if subscription_id:
195+
span.set_attribute("subscription_id", subscription_id)
196+
197+
return call_next(subscription_id=subscription_id, **kwargs)
198+
199+
def transaction_abort(self, call_next: Callable, transaction_id: int | None = None, **kwargs):
200+
"""Abort a transaction span"""
201+
# Get current span context
202+
current_span = trace.get_current_span()
203+
current_context = trace.set_span_in_context(current_span) if current_span else Context()
204+
205+
with self.tracer.start_as_current_span(
206+
"transaction.abort",
207+
context=current_context,
208+
) as span:
209+
span.set_attribute("service_name", self.service_name)
210+
211+
if transaction_id:
212+
span.set_attribute("transaction_id", transaction_id)
213+
214+
call_next(transaction_id=transaction_id, **kwargs)
215+
216+
def transaction_commit(self, call_next: Callable, transaction_id: int | None = None, **kwargs):
217+
"""Commit a transaction span"""
218+
# Get current span context
219+
current_span = trace.get_current_span()
220+
current_context = trace.set_span_in_context(current_span) if current_span else Context()
221+
222+
with self.tracer.start_as_current_span(
223+
"transaction.commit",
224+
context=current_context,
225+
) as span:
226+
span.set_attribute("service_name", self.service_name)
227+
if transaction_id:
228+
span.set_attribute("transaction_id", transaction_id)
229+
230+
call_next(transaction_id=transaction_id, **kwargs)
231+

0 commit comments

Comments
 (0)