336 lines
12 KiB
Python
336 lines
12 KiB
Python
import asyncio
|
|
import base64
|
|
import json
|
|
from typing import List, Optional, Callable
|
|
|
|
from websockets.sync.client import connect as sync_connect
|
|
from websockets.asyncio.client import connect as async_connect
|
|
|
|
from yfinance import utils
|
|
from yfinance.pricing_pb2 import PricingData
|
|
from google.protobuf.json_format import MessageToDict
|
|
|
|
|
|
class BaseWebSocket:
|
|
def __init__(self, url: str = "wss://streamer.finance.yahoo.com/?version=2", verbose=True):
|
|
self.url = url
|
|
self.verbose = verbose
|
|
self.logger = utils.get_yf_logger()
|
|
self._ws = None
|
|
self._subscriptions = set()
|
|
self._subscription_interval = 15 # seconds
|
|
|
|
def _decode_message(self, base64_message: str) -> dict:
|
|
try:
|
|
decoded_bytes = base64.b64decode(base64_message)
|
|
pricing_data = PricingData()
|
|
pricing_data.ParseFromString(decoded_bytes)
|
|
return MessageToDict(pricing_data, preserving_proto_field_name=True)
|
|
except Exception as e:
|
|
self.logger.error("Failed to decode message: %s", e, exc_info=True)
|
|
if self.verbose:
|
|
print("Failed to decode message: %s", e)
|
|
return {
|
|
'error': str(e),
|
|
'raw_base64': base64_message
|
|
}
|
|
|
|
|
|
class AsyncWebSocket(BaseWebSocket):
|
|
"""
|
|
Asynchronous WebSocket client for streaming real time pricing data.
|
|
"""
|
|
|
|
def __init__(self, url: str = "wss://streamer.finance.yahoo.com/?version=2", verbose=True):
|
|
"""
|
|
Initialize the AsyncWebSocket client.
|
|
|
|
Args:
|
|
url (str): The WebSocket server URL. Defaults to Yahoo Finance's WebSocket URL.
|
|
verbose (bool): Flag to enable or disable print statements. Defaults to True.
|
|
"""
|
|
super().__init__(url, verbose)
|
|
self._message_handler = None # Callable to handle messages
|
|
self._heartbeat_task = None # Task to send heartbeat subscribe
|
|
|
|
async def _connect(self):
|
|
try:
|
|
if self._ws is None:
|
|
self._ws = await async_connect(self.url)
|
|
self.logger.info("Connected to WebSocket.")
|
|
if self.verbose:
|
|
print("Connected to WebSocket.")
|
|
except Exception as e:
|
|
self.logger.error("Failed to connect to WebSocket: %s", e, exc_info=True)
|
|
if self.verbose:
|
|
print(f"Failed to connect to WebSocket: {e}")
|
|
self._ws = None
|
|
raise
|
|
|
|
async def _periodic_subscribe(self):
|
|
while True:
|
|
try:
|
|
await asyncio.sleep(self._subscription_interval)
|
|
|
|
if self._subscriptions:
|
|
message = {"subscribe": list(self._subscriptions)}
|
|
await self._ws.send(json.dumps(message))
|
|
|
|
if self.verbose:
|
|
print(f"Heartbeat subscription sent for symbols: {self._subscriptions}")
|
|
except Exception as e:
|
|
self.logger.error("Error in heartbeat subscription: %s", e, exc_info=True)
|
|
if self.verbose:
|
|
print(f"Error in heartbeat subscription: {e}")
|
|
break
|
|
|
|
async def subscribe(self, symbols: str | List[str]):
|
|
"""
|
|
Subscribe to a stock symbol or a list of stock symbols.
|
|
|
|
Args:
|
|
symbols (str | List[str]): Stock symbol(s) to subscribe to.
|
|
"""
|
|
await self._connect()
|
|
|
|
if isinstance(symbols, str):
|
|
symbols = [symbols]
|
|
|
|
self._subscriptions.update(symbols)
|
|
|
|
message = {"subscribe": list(self._subscriptions)}
|
|
await self._ws.send(json.dumps(message))
|
|
|
|
# Start heartbeat subscription task
|
|
if self._heartbeat_task is None:
|
|
self._heartbeat_task = asyncio.create_task(self._periodic_subscribe())
|
|
|
|
self.logger.info(f"Subscribed to symbols: {symbols}")
|
|
if self.verbose:
|
|
print(f"Subscribed to symbols: {symbols}")
|
|
|
|
async def unsubscribe(self, symbols: str | List[str]):
|
|
"""
|
|
Unsubscribe from a stock symbol or a list of stock symbols.
|
|
|
|
Args:
|
|
symbols (str | List[str]): Stock symbol(s) to unsubscribe from.
|
|
"""
|
|
await self._connect()
|
|
|
|
if isinstance(symbols, str):
|
|
symbols = [symbols]
|
|
|
|
self._subscriptions.difference_update(symbols)
|
|
|
|
message = {"unsubscribe": symbols}
|
|
await self._ws.send(json.dumps(message))
|
|
|
|
self.logger.info(f"Unsubscribed from symbols: {symbols}")
|
|
if self.verbose:
|
|
print(f"Unsubscribed from symbols: {symbols}")
|
|
|
|
async def listen(self, message_handler=None):
|
|
"""
|
|
Start listening to messages from the WebSocket server.
|
|
|
|
Args:
|
|
message_handler (Optional[Callable[[dict], None]]): Optional function to handle received messages.
|
|
"""
|
|
await self._connect()
|
|
self._message_handler = message_handler
|
|
|
|
self.logger.info("Listening for messages...")
|
|
if self.verbose:
|
|
print("Listening for messages...")
|
|
|
|
# Start heartbeat subscription task
|
|
if self._heartbeat_task is None:
|
|
self._heartbeat_task = asyncio.create_task(self._periodic_subscribe())
|
|
|
|
while True:
|
|
try:
|
|
async for message in self._ws:
|
|
message_json = json.loads(message)
|
|
encoded_data = message_json.get("message", "")
|
|
decoded_message = self._decode_message(encoded_data)
|
|
|
|
if self._message_handler:
|
|
try:
|
|
if asyncio.iscoroutinefunction(self._message_handler):
|
|
await self._message_handler(decoded_message)
|
|
else:
|
|
self._message_handler(decoded_message)
|
|
except Exception as handler_exception:
|
|
self.logger.error("Error in message handler: %s", handler_exception, exc_info=True)
|
|
if self.verbose:
|
|
print("Error in message handler:", handler_exception)
|
|
else:
|
|
print(decoded_message)
|
|
|
|
except (KeyboardInterrupt, asyncio.CancelledError):
|
|
self.logger.info("WebSocket listening interrupted. Closing connection...")
|
|
if self.verbose:
|
|
print("WebSocket listening interrupted. Closing connection...")
|
|
await self.close()
|
|
break
|
|
|
|
except Exception as e:
|
|
self.logger.error("Error while listening to messages: %s", e, exc_info=True)
|
|
if self.verbose:
|
|
print("Error while listening to messages: %s", e)
|
|
|
|
# Attempt to reconnect if connection drops
|
|
self.logger.info("Attempting to reconnect...")
|
|
if self.verbose:
|
|
print("Attempting to reconnect...")
|
|
await asyncio.sleep(3) # backoff
|
|
await self._connect()
|
|
|
|
async def close(self):
|
|
"""Close the WebSocket connection."""
|
|
if self._heartbeat_task:
|
|
self._heartbeat_task.cancel()
|
|
|
|
if self._ws is not None: # and not self._ws.closed:
|
|
await self._ws.close()
|
|
self.logger.info("WebSocket connection closed.")
|
|
if self.verbose:
|
|
print("WebSocket connection closed.")
|
|
|
|
async def __aenter__(self):
|
|
await self._connect()
|
|
return self
|
|
|
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
|
await self.close()
|
|
|
|
|
|
class WebSocket(BaseWebSocket):
|
|
"""
|
|
Synchronous WebSocket client for streaming real time pricing data.
|
|
"""
|
|
|
|
def __init__(self, url: str = "wss://streamer.finance.yahoo.com/?version=2", verbose=True):
|
|
"""
|
|
Initialize the WebSocket client.
|
|
|
|
Args:
|
|
url (str): The WebSocket server URL. Defaults to Yahoo Finance's WebSocket URL.
|
|
verbose (bool): Flag to enable or disable print statements. Defaults to True.
|
|
"""
|
|
super().__init__(url, verbose)
|
|
|
|
def _connect(self):
|
|
try:
|
|
if self._ws is None:
|
|
self._ws = sync_connect(self.url)
|
|
self.logger.info("Connected to WebSocket.")
|
|
if self.verbose:
|
|
print("Connected to WebSocket.")
|
|
except Exception as e:
|
|
self.logger.error("Failed to connect to WebSocket: %s", e, exc_info=True)
|
|
if self.verbose:
|
|
print(f"Failed to connect to WebSocket: {e}")
|
|
self._ws = None
|
|
raise
|
|
|
|
def subscribe(self, symbols: str | List[str]):
|
|
"""
|
|
Subscribe to a stock symbol or a list of stock symbols.
|
|
|
|
Args:
|
|
symbols (str | List[str]): Stock symbol(s) to subscribe to.
|
|
"""
|
|
self._connect()
|
|
|
|
if isinstance(symbols, str):
|
|
symbols = [symbols]
|
|
|
|
self._subscriptions.update(symbols)
|
|
|
|
message = {"subscribe": list(self._subscriptions)}
|
|
self._ws.send(json.dumps(message))
|
|
|
|
self.logger.info(f"Subscribed to symbols: {symbols}")
|
|
if self.verbose:
|
|
print(f"Subscribed to symbols: {symbols}")
|
|
|
|
def unsubscribe(self, symbols: str | List[str]):
|
|
"""
|
|
Unsubscribe from a stock symbol or a list of stock symbols.
|
|
|
|
Args:
|
|
symbols (str | List[str]): Stock symbol(s) to unsubscribe from.
|
|
"""
|
|
self._connect()
|
|
|
|
if isinstance(symbols, str):
|
|
symbols = [symbols]
|
|
|
|
self._subscriptions.difference_update(symbols)
|
|
|
|
message = {"unsubscribe": symbols}
|
|
self._ws.send(json.dumps(message))
|
|
|
|
self.logger.info(f"Unsubscribed from symbols: {symbols}")
|
|
if self.verbose:
|
|
print(f"Unsubscribed from symbols: {symbols}")
|
|
|
|
def listen(self, message_handler: Optional[Callable[[dict], None]] = None):
|
|
"""
|
|
Start listening to messages from the WebSocket server.
|
|
|
|
Args:
|
|
message_handler (Optional[Callable[[dict], None]]): Optional function to handle received messages.
|
|
"""
|
|
self._connect()
|
|
|
|
self.logger.info("Listening for messages...")
|
|
if self.verbose:
|
|
print("Listening for messages...")
|
|
|
|
while True:
|
|
try:
|
|
message = self._ws.recv()
|
|
message_json = json.loads(message)
|
|
encoded_data = message_json.get("message", "")
|
|
decoded_message = self._decode_message(encoded_data)
|
|
|
|
if message_handler:
|
|
try:
|
|
message_handler(decoded_message)
|
|
except Exception as handler_exception:
|
|
self.logger.error("Error in message handler: %s", handler_exception, exc_info=True)
|
|
if self.verbose:
|
|
print("Error in message handler:", handler_exception)
|
|
else:
|
|
print(decoded_message)
|
|
|
|
except KeyboardInterrupt:
|
|
if self.verbose:
|
|
print("Received keyboard interrupt.")
|
|
self.close()
|
|
break
|
|
|
|
except Exception as e:
|
|
self.logger.error("Error while listening to messages: %s", e, exc_info=True)
|
|
if self.verbose:
|
|
print("Error while listening to messages: %s", e)
|
|
break
|
|
|
|
def close(self):
|
|
"""Close the WebSocket connection."""
|
|
if self._ws is not None:
|
|
self._ws.close()
|
|
self.logger.info("WebSocket connection closed.")
|
|
if self.verbose:
|
|
print("WebSocket connection closed.")
|
|
|
|
def __enter__(self):
|
|
self._connect()
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc_value, traceback):
|
|
self.close()
|