import { defaultGraphMargins } from '../../CONSTANTS';
import ChartLeftLabel from '../../common/ChartLeftLabel';
import ChartTopLabel from '../../common/ChartTopLabel';
import {
	StyledAxisGroup,
	StyledSVGContainer,
} from '../../common/styledComponents';
import useLinearXScale from '../../hooks/useLinearXScale';
import useLinearYAxis from '../../hooks/useLinearYAxis';
import useLinearYScale from '../../hooks/useLinearYScale';
import useXAxis from '../../hooks/useXAxis';
import { LeafArray, PointDrawFn } from '../../types';
import { nanoid } from '@reduxjs/toolkit';
import useElementSize from 'common/hooks/useSize';
import {
	Fragment,
	FunctionComponent,
	ReactElement,
	useMemo,
} from 'react';
import styled, { CSSProperties } from 'styled-components';

/** SVG circle whose fill is taken from the theme's primary palette colour. */
const StyledPoint = styled.circle`
	fill: ${({ theme }) => theme.palette.primary.main};
`;

/**
 * Fallback point renderer: a half-transparent circle of radius 2, filled with
 * the theme's primary colour and centred on the already-scaled coordinates
 * (`drawX`, `drawY`).
 */
const drawDefaultPoint: PointDrawFn = (point) => {
	const { drawX: centerX, drawY: centerY } = point;

	// NOTE(review): nanoid() yields a brand-new key on every call, so React
	// cannot match this element across redraws; callers that need stable
	// reconciliation should key the element themselves.
	return (
		<StyledPoint
			cx={centerX}
			cy={centerY}
			r={2}
			opacity={0.5}
			key={nanoid()}
			data-testid="default-point"
		/>
	);
};

/** Optional inline-style overrides for the scatterplot's DOM elements. */
interface Overrides {
	/** Inline style for the outer `StyledSVGContainer` wrapper. */
	wrapper?: CSSProperties;
	/** Inline style for the inner `<svg>` element. */
	svg?: CSSProperties;
}

/** Props accepted by the `Scatterplot` component. */
interface ScatterplotProps {
	/** Nested collections of points; each point's `x` / `y` go through the linear scales. */
	points: LeafArray;
	/** Lower data bound used to build the x scale. */
	xMin: number;
	/** Upper data bound used to build the x scale. */
	xMax: number;
	/** Upper data bound used to build the y scale. */
	yMax: number;

	/** Whether to render the x-axis group (defaults to `true`). */
	xAxis?: boolean;
	/** Whether to render the y-axis group (defaults to `true`). */
	yAxis?: boolean;
	/** Label rendered above the chart; nothing is drawn when empty or absent. */
	topLabel?: string;
	/** Label rendered along the left edge; nothing is drawn when empty or absent. */
	leftLabel?: string;

	/** Custom renderer for a single point; the default draws a small circle. */
	drawPoint?: PointDrawFn;
	/** `id` attribute applied to the `<svg>` element. */
	svgId?: string;
	/** Inline-style overrides for the wrapper and the `<svg>`. */
	overrides?: Overrides;

	/** Top margin; falls back to `defaultGraphMargins.top` when not supplied. */
	top?: number;
	/** Bottom margin; falls back to `defaultGraphMargins.bottom` when not supplied. */
	bottom?: number;
	/** Left margin; falls back to `defaultGraphMargins.left` when not supplied. */
	left?: number;
	/** Right margin; falls back to `defaultGraphMargins.right` when not supplied. */
	right?: number;
}

/**
 * Responsive scatterplot: measures its wrapper element and draws every point
 * of `points` into an SVG whose viewBox matches the measured size.
 *
 * The x scale is built from `xMin` / `xMax` with ratios 0.9 / 1.1, and the
 * y scale from `yMax` with ratio 1.1. NOTE(review): the ratios presumably pad
 * the data domain; confirm against useLinearXScale / useLinearYScale.
 *
 * Each margin (`top`, `bottom`, `left`, `right`) falls back to the matching
 * `defaultGraphMargins` entry when it is omitted or passed as `undefined`.
 */
const Scatterplot: FunctionComponent<ScatterplotProps> = ({
	xAxis = true,
	yAxis = true,
	topLabel,
	leftLabel,
	overrides,
	xMin,
	xMax,
	yMax,
	points,
	drawPoint,
	svgId,
	// Destructuring defaults also apply when a value is explicitly
	// `undefined` (e.g. a wrapper forwarding an unset prop). A plain object
	// spread over the defaults would let that `undefined` win instead.
	top = defaultGraphMargins.top,
	bottom = defaultGraphMargins.bottom,
	left = defaultGraphMargins.left,
	right = defaultGraphMargins.right,
}) => {
	const draw = drawPoint ?? drawDefaultPoint;

	// Size of the wrapper element, tracked through the ref callback passed to
	// StyledSVGContainer below; the viewBox and both scales derive from it.
	const [size, setSizeEl] = useElementSize();
	const { width, height } = size;

	// Tick-count hint for the y axis: one tick per 50 units of height.
	const tickCount = height / 50;

	const xScale = useLinearXScale({
		xMinRatio: 0.9,
		xMaxRatio: 1.1,
		xMin,
		xMax,
		width,
		left,
		right,
	});

	const yScale = useLinearYScale({
		yMaxRatio: 1.1,
		top,
		bottom,
		height,
		yMax,
	});

	const yAxisClass = useLinearYAxis(yScale, left, tickCount);

	const xAxisClass = useXAxis(xScale, height, bottom);

	// Draw every point of every series. Each drawn element is wrapped in a
	// Fragment keyed by its position in the flattened list, so every list
	// child has a unique, deterministic key regardless of what key (if any)
	// `drawPoint` assigns.
	const pointGroup = useMemo(() => {
		const drawnPoints: ReactElement[] = [];
		points.forEach((series) => {
			series.forEach((point) => {
				drawnPoints.push(
					<Fragment key={drawnPoints.length}>
						{draw({
							...point,
							drawX: xScale(point.x),
							drawY: yScale(point.y),
						})}
					</Fragment>
				);
			});
		});
		return <g>{drawnPoints}</g>;
	}, [points, draw, xScale, yScale]);

	return (
		<StyledSVGContainer ref={setSizeEl} style={overrides?.wrapper}>
			<svg
				viewBox={`0 0 ${width} ${height}`}
				id={svgId}
				style={overrides?.svg}
			>
				{topLabel && (
					<ChartTopLabel
						width={width}
						left={left}
						top={top}
						label={topLabel}
					/>
				)}
				{leftLabel && (
					<ChartLeftLabel
						height={height}
						bottom={bottom}
						label={leftLabel}
					/>
				)}
				{xAxis && <StyledAxisGroup className={xAxisClass} />}
				{yAxis && <StyledAxisGroup className={yAxisClass} />}
				{pointGroup}
			</svg>
		</StyledSVGContainer>
	);
};

// The Scatterplot component is exposed as this module's default export.
export default Scatterplot;
