161 lines
4.6 KiB
Python
161 lines
4.6 KiB
Python
from __future__ import annotations
|
|
|
|
import pandas as pd
|
|
import plotly.graph_objects as go
|
|
from plotly.subplots import make_subplots
|
|
|
|
from .constants import TREND_BEAR, TREND_BULL, TREND_NEUTRAL
|
|
|
|
|
|
def _is_intraday_interval(interval: str) -> bool:
|
|
return interval in {"1m", "2m", "5m", "15m", "30m", "60m", "90m", "1h"}
|
|
|
|
|
|
def _is_daily_interval(interval: str) -> bool:
|
|
return interval == "1d"
|
|
|
|
|
|
def _infer_session_bounds(df: pd.DataFrame) -> tuple[float, float] | None:
|
|
if df.empty:
|
|
return None
|
|
|
|
index = pd.DatetimeIndex(df.index)
|
|
if index.tz is None:
|
|
return None
|
|
|
|
minutes = index.hour * 60 + index.minute
|
|
session_df = pd.DataFrame({"date": index.date, "minute": minutes})
|
|
day_bounds = session_df.groupby("date")["minute"].agg(["min", "max"])
|
|
if day_bounds.empty:
|
|
return None
|
|
|
|
start_minute = float(day_bounds["min"].median())
|
|
# Include the final candle width roughly by adding one median step when possible.
|
|
if len(index) > 1:
|
|
deltas = pd.Series(index[1:] - index[:-1]).dt.total_seconds().div(60.0)
|
|
step = float(deltas[deltas > 0].median()) if not deltas[deltas > 0].empty else 0.0
|
|
else:
|
|
step = 0.0
|
|
end_minute = float(day_bounds["max"].median() + step)
|
|
|
|
return end_minute / 60.0, start_minute / 60.0
|
|
|
|
|
|
def build_figure(
    df: pd.DataFrame,
    gray_fake: bool,
    *,
    interval: str,
    hide_market_closed_gaps: bool,
) -> go.Figure:
    """Build a two-row price/volume chart with real-signal markers.

    Row 1 holds a candlestick of every bar plus triangle markers at the close
    of bars whose ``classification`` is ``"real_bull"`` / ``"real_bear"``.
    Row 2 holds volume bars coloured by ``trend_state``. Both rows share the
    x-axis.

    Args:
        df: Bars indexed by timestamp with ``Open``, ``High``, ``Low``,
            ``Close``, ``Volume``, ``classification`` and ``trend_state``
            columns.
        gray_fake: When true, draw all candles in muted grays at low opacity
            (presumably so non-real bars recede behind the real-signal
            markers); otherwise draw green/red candles.
        interval: Bar interval code (e.g. ``"5m"``, ``"1d"``). Intraday codes
            additionally get an overnight hour rangebreak.
        hide_market_closed_gaps: Collapse weekends and, for intraday
            intervals, overnight closed hours on the x-axis.

    Returns:
        The assembled Plotly figure.
    """
    fig = make_subplots(
        rows=2,
        cols=1,
        row_heights=[0.8, 0.2],
        vertical_spacing=0.03,
        shared_xaxes=True,
    )

    # Only the candle colours and opacity depend on ``gray_fake``.
    if gray_fake:
        candle_style = dict(
            increasing_line_color="#B0B0B0",
            decreasing_line_color="#808080",
            opacity=0.35,
        )
    else:
        candle_style = dict(
            increasing_line_color="#2E8B57",
            decreasing_line_color="#B22222",
            opacity=0.6,
        )
    fig.add_trace(
        go.Candlestick(
            x=df.index,
            open=df["Open"],
            high=df["High"],
            low=df["Low"],
            close=df["Close"],
            name="All Bars",
            **candle_style,
        ),
        row=1,
        col=1,
    )

    # One marker trace per real-signal class, plotted at the bar's close.
    signal_markers = (
        ("real_bull", "Real Bullish", "#00C853", "triangle-up"),
        ("real_bear", "Real Bearish", "#D50000", "triangle-down"),
    )
    for label, trace_name, color, symbol in signal_markers:
        mask = df["classification"] == label
        fig.add_trace(
            go.Scatter(
                x=df.index[mask],
                y=df.loc[mask, "Close"],
                mode="markers",
                name=trace_name,
                marker=dict(color=color, size=9, symbol=symbol),
            ),
            row=1,
            col=1,
        )

    neutral_color = "#9E9E9E"
    # ``Series.map`` yields NaN for any trend_state not in the palette
    # (including missing values); paint those bars neutral instead of passing
    # a non-colour entry to Plotly.
    trend_color = (
        df["trend_state"]
        .map(
            {
                TREND_BULL: "#00C853",
                TREND_BEAR: "#D50000",
                TREND_NEUTRAL: neutral_color,
            }
        )
        .fillna(neutral_color)
    )
    fig.add_trace(
        go.Bar(
            x=df.index,
            y=df["Volume"],
            marker_color=trend_color,
            name="Volume",
            opacity=0.65,
        ),
        row=2,
        col=1,
    )

    fig.update_layout(
        template="plotly_white",
        xaxis_rangeslider_visible=False,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0),
        margin=dict(l=20, r=20, t=60, b=20),
        height=760,
    )

    if hide_market_closed_gaps:
        # The weekend break applies to every interval. On daily charts it alone
        # removes the non-trading spacing a continuous date axis would show.
        rangebreaks: list[dict[str, object]] = [dict(bounds=["sat", "mon"])]
        if _is_intraday_interval(interval):
            # Collapse the overnight closed hours inferred from the data's
            # timezone/session. Without an inference, fall back to a fixed
            # 16:00 -> 09:30 wall-clock window (looks like US equity regular
            # hours — assumption).
            inferred_bounds = _infer_session_bounds(df)
            hour_bounds = list(inferred_bounds) if inferred_bounds is not None else [16, 9.5]
            rangebreaks.append(dict(pattern="hour", bounds=hour_bounds))
        fig.update_xaxes(rangebreaks=rangebreaks)

    fig.update_yaxes(title_text="Price", row=1, col=1)
    fig.update_yaxes(title_text="Volume", row=2, col=1)
    return fig
|