Skip to content

Commit 0609d0c

Browse files
wrap observations in a MessageWrapper
1 parent cc29015 commit 0609d0c

2 files changed

Lines changed: 93 additions & 28 deletions

File tree

oshconnect/datasource/datasource.py

Lines changed: 91 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@
33
# Author: Ian Patterson <ian@botts-inc.com>
44
#
55
# Contact Email: ian@botts-inc.com
6+
from __future__ import annotations
67
import asyncio
8+
import json
79
from uuid import uuid4
810

911
import websockets
12+
from conSys4Py.datamodels.observations import ObservationOMJSONInline
1013

1114
from external_models.object_models import DatastreamResource
1215
from oshconnect import Utilities
@@ -60,37 +63,55 @@ def update_properties(self, properties: dict):
6063
# TODO: need to stop in progress sub-processes and restart
6164
self.properties = properties
6265

66+
def set_mode(self, mode: str):
67+
self.mode = mode
68+
6369
def initialize(self):
64-
pass
70+
if self._websocket.is_open():
71+
self._websocket.close()
72+
self._websocket = None
73+
self._status = "initialized"
6574

6675
async def connect(self):
6776
if self.mode == "websocket":
6877
self._websocket = await websockets.connect(self._url, extra_headers=self._extra_headers)
6978
self._status = "connected"
7079
return self._websocket
80+
elif self.mode == "playback":
81+
self._status = "connected"
82+
return "Playback mode is not yet implemented."
83+
elif self.mode == "live-batch":
84+
self._status = "connected"
85+
return "Live-batch mode is not yet implemented."
7186

87+
def disconnect(self):
88+
self._websocket.close()
7289

73-
def disconnect(self):
74-
pass
75-
76-
77-
def reset(self):
78-
pass
90+
def reset(self):
91+
self._websocket = None
92+
self._status = "initialized"
7993

94+
def get_status(self):
95+
return self._status
8096

81-
def get_status(self):
82-
return self.status
97+
def get_ws_client(self):
98+
return self._websocket
8399

84100

85101
class DataSourceHandler:
86102
datasource_map: dict[str, DataSource]
103+
_message_list: MessageHandler
87104

88105
def __init__(self):
89106
self.datasource_map = {}
107+
self._message_list = MessageHandler()
90108

91109
def add_datasource(self, datasource: DataSource):
92110
self.datasource_map[datasource.get_id()] = datasource
93111

112+
def remove_datasource(self, datasource_id: str):
113+
return self.datasource_map.pop(datasource_id)
114+
94115
def initialize_ds(self, datasource_id: str, properties: dict):
95116
ds = self.datasource_map.get(datasource_id)
96117
ds.initialize()
@@ -99,17 +120,76 @@ def initialize_all(self):
99120
# list comp is faster than for loop
100121
[ds.initialize() for ds in self.datasource_map.values()]
101122

123+
def set_ds_mode(self, mode: str):
124+
var = (ds.set_mode(mode) for ds in self.datasource_map.values())
125+
102126
async def connect_ds(self, datasource_id: str):
103127
ds = self.datasource_map.get(datasource_id)
104128
await ds.connect()
105129

106130
async def connect_all(self):
107-
results = await asyncio.gather(*(ds.connect() for ds in self.datasource_map.values()))
108-
return results
131+
# call connect for all datasources
132+
[(ds, await ds.connect()) for ds in self.datasource_map.values()]
133+
for ds in self.datasource_map.values():
134+
task = asyncio.create_task(self._handle_datastream_client(ds))
135+
# return task
109136

110137
def disconnect_ds(self, datasource_id: str):
111138
ds = self.datasource_map.get(datasource_id)
112139
ds.disconnect()
113140

114141
def disconnect_all(self):
115142
[ds.disconnect() for ds in self.datasource_map.values()]
143+
144+
async def _handle_datastream_client(self, datasource: DataSource):
145+
try:
146+
async for msg in datasource.get_ws_client():
147+
msg_dict = json.loads(msg.decode('utf-8'))
148+
obs = ObservationOMJSONInline.model_validate(msg_dict)
149+
msg_wrapper = MessageWrapper(datasource=datasource, message=obs)
150+
self._message_list.add_message(msg_wrapper)
151+
152+
except Exception as e:
153+
print(f"An error occurred while reading from websocket: {e}")
154+
155+
156+
class MessageHandler:
157+
_message_list: list[MessageWrapper]
158+
159+
def __init__(self):
160+
self._message_list = []
161+
162+
def add_message(self, message: MessageWrapper):
163+
self._message_list.append(message)
164+
print(self._message_list)
165+
166+
def get_messages(self):
167+
return self._message_list
168+
169+
def clear_messages(self):
170+
self._message_list.clear()
171+
172+
def sort_messages(self):
173+
# copy the list
174+
sorted_list = self._message_list.copy()
175+
sorted_list.sort(key=lambda x: x.resultTime)
176+
return sorted_list
177+
178+
179+
class MessageWrapper:
180+
"""
181+
Combines a DataSource and a Message into a single object for easier access
182+
"""
183+
184+
def __init__(self, datasource: DataSource, message: ObservationOMJSONInline):
185+
self._message = message
186+
self._datasource = datasource
187+
188+
def get_message(self):
189+
return self._message
190+
191+
def get_message_as_dict(self):
192+
return self._message.dict()
193+
194+
def __repr__(self):
195+
return f"{self._datasource}, {self._message}"

oshconnect/oshconnect.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,10 @@ def select_temporal_mode(self, mode: str):
8282

8383
async def playback_streams(self, stream_ids: list = None):
8484
if stream_ids is None:
85-
clients = await self._datasource_handler.connect_all()
86-
for client in clients:
87-
task = asyncio.create_task(self._handle_datastream_client(client))
88-
self._tasks.append(task)
85+
await self._datasource_handler.connect_all()
8986
else:
9087
for stream_id in stream_ids:
91-
clients = await self._datasource_handler.connect_ds(stream_id)
92-
for client in clients:
93-
msg = await client.recv()
94-
print(msg)
88+
await self._datasource_handler.connect_ds(stream_id)
9589

9690
def visualize_streams(self, streams: list):
9791
pass
@@ -128,12 +122,3 @@ def authenticate_user(self, user: dict):
128122

129123
def synchronize_streams(self, systems: list):
130124
pass
131-
132-
async def _handle_datastream_client(self, client):
133-
try:
134-
async for msg in client:
135-
msg_dict = json.loads(msg.decode('utf-8'))
136-
obs = ObservationOMJSONInline.model_validate(msg_dict)
137-
138-
except Exception as e:
139-
print(f"An error occurred while reading from websocket: {e}")

0 commit comments

Comments
 (0)