# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tfx_bsl.sketches.QuantilesSketch."""

import itertools
import pickle

import numpy as np
import pyarrow as pa
from tfx_bsl import sketches

from absl.testing import absltest
from absl.testing import parameterized


# Fixed-expectation cases shared by test_quantiles, test_pickle, test_merge and
# test_compact. Each case provides:
#   values: list of pa.Array batches; each batch is added with one AddValues
#     call (via _add_values).
#   weights: optional list of weight batches, parallel to `values`.
#   expected: one list of quantile boundaries per stream; the tests call
#     GetQuantiles(len(expected[0]) - 1), and the data shows the first and last
#     boundaries are the stream's min and max.
#   num_streams: number of streams the sketch is created with.
_QUANTILES_TEST_CASES = [
    # Values 1..300 in three batches, no explicit weights.
    dict(
        testcase_name="unweighted",
        values=[
            pa.array(np.linspace(1, 100, 100, dtype=np.float64)),
            pa.array(np.linspace(101, 200, 100, dtype=np.float64)),
            pa.array(np.linspace(201, 300, 100, dtype=np.float64)),
        ],
        expected=[[1, 61, 121, 181, 241, 300]],
        num_streams=1),
    # 500 values per batch spread over 5 streams; expected[i] is the result
    # for stream i. NOTE(review): the expected minima 1..5 and maxima 696..700
    # suggest consecutive values are assigned round-robin to streams — confirm
    # against QuantilesSketch.AddValues.
    dict(
        testcase_name="unweighted_elementwise",
        values=[
            pa.array(np.linspace(1, 500, 500, dtype=np.float64)),
            pa.array(np.linspace(101, 600, 500, dtype=np.float64)),
            pa.array(np.linspace(201, 700, 500, dtype=np.float64)),
        ],
        expected=[[1, 201, 301, 401, 501, 696], [2, 202, 302, 402, 502, 697],
                  [3, 203, 303, 403, 503, 698], [4, 204, 304, 404, 504, 699],
                  [5, 205, 305, 405, 505, 700]],
        num_streams=5),
    # Same values as "unweighted", but the three batches weigh 1, 2 and 3.
    dict(
        testcase_name="weighted",
        values=[
            pa.array(np.linspace(1, 100, 100, dtype=np.float64)),
            pa.array(np.linspace(101, 200, 100, dtype=np.float64)),
            pa.array(np.linspace(201, 300, 100, dtype=np.float64)),
        ],
        weights=[
            pa.array([1] * 100, type=pa.float64()),
            pa.array([2] * 100, type=pa.float64()),
            pa.array([3] * 100, type=pa.float64()),
        ],
        expected=[[1, 111, 171, 221, 261, 300]],
        num_streams=1),
    # Weighted multi-stream case. Each 500-value batch has only 100 weights;
    # presumably one weight applies to each group of num_streams consecutive
    # values — confirm against QuantilesSketch.AddValues.
    dict(
        testcase_name="weighted_elementwise",
        values=[
            pa.array(np.linspace(1, 500, 500, dtype=np.float64)),
            pa.array(np.linspace(101, 600, 500, dtype=np.float64)),
            pa.array(np.linspace(201, 700, 500, dtype=np.float64)),
        ],
        weights=[
            pa.array([1] * 100, type=pa.float64()),
            pa.array([2] * 100, type=pa.float64()),
            pa.array([3] * 100, type=pa.float64()),
        ],
        expected=[[1, 231, 331, 431, 541, 696], [2, 232, 332, 432, 542, 697],
                  [3, 233, 333, 433, 543, 698], [4, 234, 334, 434, 544, 699],
                  [5, 235, 335, 435, 545, 700]],
        num_streams=5),
    # Inputs containing +inf and -inf, which may appear as quantile boundaries.
    dict(
        testcase_name="infinity",
        values=[
            pa.array(
                [1.0, 2.0, np.inf, np.inf, -np.inf, 3.0, 4.0, 5.0, -np.inf]),
            pa.array([1.0, np.inf, -np.inf]),
        ],
        expected=[[-np.inf, -np.inf, 1, 4, np.inf, np.inf]],
        num_streams=1),
    # "unweighted" plus a batch of nulls; the expectation is identical to
    # "unweighted", so nulls must not affect the quantiles.
    dict(
        testcase_name="null",
        values=[
            pa.array(np.linspace(1, 100, 100, dtype=np.float64)),
            pa.array(np.linspace(101, 200, 100, dtype=np.float64)),
            pa.array(np.linspace(201, 300, 100, dtype=np.float64)),
            pa.array([None, None]),
        ],
        expected=[[1, 61, 121, 181, 241, 300]],
        num_streams=1),
    # Same values as "unweighted" but as int32; the expectation is identical.
    dict(
        testcase_name="int",
        values=[
            pa.array(np.linspace(1, 100, 100, dtype=np.int32)),
            pa.array(np.linspace(101, 200, 100, dtype=np.int32)),
            pa.array(np.linspace(201, 300, 100, dtype=np.int32)),
        ],
        expected=[[1, 61, 121, 181, 241, 300]],
        num_streams=1),
    # "weighted" plus two huge values carrying weights 0 and -1; the
    # expectation is identical to "weighted", so those values must not
    # contribute to the quantiles.
    dict(
        testcase_name="negative_weights",
        values=[
            pa.array(np.linspace(1, 100, 100, dtype=np.float64)),
            pa.array(np.linspace(101, 200, 100, dtype=np.float64)),
            pa.array(np.linspace(201, 300, 100, dtype=np.float64)),
            pa.array([100000000, 200000000]),
        ],
        weights=[
            pa.array([1] * 100, type=pa.float64()),
            pa.array([2] * 100, type=pa.float64()),
            pa.array([3] * 100, type=pa.float64()),
            pa.array([0, -1]),
        ],
        expected=[[1, 111, 171, 221, 261, 300]],
        num_streams=1),
]

_MAX_NUM_ELEMENTS = [2**10, 2**14, 2**18, 2**19]
_EPS = [0.5, 0.01, 0.001, 0.0001, 0.000001]
_NUM_QUANTILES = [5**4, 3**6, 1000, 2**10]

_ACCURACY_TEST_CASES = list(
    itertools.product(_MAX_NUM_ELEMENTS, _EPS, _NUM_QUANTILES))


def _add_values(sketch, value, weight):
  if weight is None:
    sketch.AddValues(value)
  else:
    sketch.AddValues(value, weight)


def _pickle_roundtrip(s):
  return pickle.loads(pickle.dumps(s))


class QuantilesSketchTest(parameterized.TestCase):
  """Tests construction, queries, pickling, merging and compaction of sketches.

  The fixed-expectation tests run over _QUANTILES_TEST_CASES; the accuracy
  tests sweep _ACCURACY_TEST_CASES and check results against an exact cdf.
  """

  # ---------------------------------------------------------------------------
  # Helpers for the fixed-expectation tests.
  # ---------------------------------------------------------------------------

  @staticmethod
  def _new_sketch(num_streams):
    """Returns a sketch with eps=0.00001 and max_num_elements=1 << 32."""
    return sketches.QuantilesSketch(0.00001, 1 << 32, num_streams)

  @staticmethod
  def _value_weight_pairs(values, weights):
    """Pairs values[i] with weights[i], or with None when weights is None.

    Splitting the resulting list keeps every value batch aligned with its own
    weight batch.
    """
    if weights is None:
      weights = [None] * len(values)
    return list(zip(values, weights))

  @staticmethod
  def _add_pairs(sketch, pairs):
    """Adds every (value, weight) batch in `pairs` to `sketch` in order."""
    for value, weight in pairs:
      _add_values(sketch, value, weight)

  def _assert_quantiles(self, sketch, expected):
    """Asserts that `sketch` yields `expected` (one boundary list per stream).

    GetQuantiles is asked for len(expected[0]) - 1 quantiles, matching how the
    expectations in _QUANTILES_TEST_CASES were written.
    """
    result = sketch.GetQuantiles(len(expected[0]) - 1).to_pylist()
    np.testing.assert_almost_equal(expected, result)

  # ---------------------------------------------------------------------------
  # Helpers for the accuracy tests.
  # ---------------------------------------------------------------------------

  @staticmethod
  def _accuracy_inputs(max_num_elements):
    """Builds the weighted input and its exact cdf for the accuracy tests.

    Values are max_num_elements - 1, ..., 1, 0 and weights are
    0, 1, ..., max_num_elements - 1, so value v carries weight
    max_num_elements - 1 - v.

    Args:
      max_num_elements: number of input values N.

    Returns:
      (values, weights, cdf), where cdf(x) is the fraction of the total
      weight carried by values <= x, for integer x in [-1, N - 1].
    """
    values = pa.array(reversed(range(max_num_elements)))
    weights = pa.array(range(max_num_elements))
    total_weight = (max_num_elements - 1) * max_num_elements / 2

    def cdf(x):
      # Closed form of sum_{v=0}^{x} (N - 1 - v).
      left_weight = (2 * (max_num_elements - 1) - x) * (x + 1) / 2
      return left_weight / total_weight

    return values, weights, cdf

  @staticmethod
  def _three_sketches(eps, max_num_elements, values, weights):
    """Splits (values, weights) at 1/10 and 1/3 into three separate sketches.

    Returns:
      [s1, s2, s3], one single-stream sketch per consecutive slice.
    """
    bounds = [0, max_num_elements // 10, max_num_elements // 3, None]
    result = []
    for start, stop in zip(bounds, bounds[1:]):
      sketch = sketches.QuantilesSketch(eps, max_num_elements, 1)
      _add_values(sketch, values[start:stop], weights[start:stop])
      result.append(sketch)
    return result

  def _assert_single_stream_accuracy(self, sketch, num_quantiles, cdf, eps):
    """Checks the quantiles of the first stream of `sketch` against `cdf`."""
    quantiles = sketch.GetQuantiles(num_quantiles - 1).to_pylist()[0]
    self.assert_quantiles_accuracy(quantiles, cdf, eps)

  def assert_quantiles_accuracy(self, quantiles, cdf, eps):
    """Validates quantile accuracy against an exact cdf function.

    Assumes the quantiles' input values are of the form range(N), so the
    neighbours of a quantile q are q - 1 and q + 1. The order of quantiles is
    validated too, since their cdf values are compared to ordered levels.

    Args:
      quantiles: quantile boundaries, expected to sit at evenly spaced levels
        0, 1/(k-1), ..., 1 for k = len(quantiles).
      cdf: exact cumulative distribution function of the input.
      eps: allowed absolute error of a quantile's cdf.
    """
    num_quantiles = len(quantiles)
    # Levels are spread over [0, 1] with num_quantiles - 1 gaps, so at least
    # two quantiles are needed. This also stops an empty result from passing.
    self.assertGreaterEqual(
        num_quantiles, 2,
        "Expected at least 2 quantiles, got {}.".format(num_quantiles))
    expected_levels = [i / (num_quantiles - 1) for i in range(num_quantiles)]
    for level, quantile in zip(expected_levels, quantiles):
      quantile_cdf = cdf(quantile)
      left_cdf = cdf(quantile - 1)
      right_cdf = cdf(quantile + 1)
      # A quantile is accepted in any of these cases:
      # - its own cdf is within eps of the target level;
      # - the level lies strictly between its neighbours' cdfs;
      # - it is the 0-level quantile and nothing lies to its left.
      accurate = (
          abs(level - quantile_cdf) < eps or left_cdf < level < right_cdf or
          (level == 0 and left_cdf == 0))
      if not accurate:
        # Format the message only on failure. This check runs for every
        # quantile of every accuracy case.
        self.fail(
            "Accuracy of the given quantile is not sufficient, "
            "quantile={} of expected level {}, its cdf is {}; cdf of a value "
            "to the left is {}, to the right is {}. Error bound = {}.".format(
                quantile, level, quantile_cdf, left_cdf, right_cdf, eps))

  # ---------------------------------------------------------------------------
  # Tests.
  # ---------------------------------------------------------------------------

  def test_quantiles_sketch_init(self):
    with self.assertRaisesRegex(RuntimeError, "eps must be positive"):
      _ = sketches.QuantilesSketch(0, 1 << 32, 1)

    with self.assertRaisesRegex(RuntimeError, "max_num_elements must be >= 1."):
      _ = sketches.QuantilesSketch(0.0001, 0, 1)

    with self.assertRaisesRegex(RuntimeError, "num_streams must be >= 1."):
      _ = sketches.QuantilesSketch(0.0001, 1 << 32, 0)

    # Valid arguments must construct without raising.
    _ = sketches.QuantilesSketch(0.0001, 1 << 32, 1)

  @parameterized.named_parameters(*_QUANTILES_TEST_CASES)
  def test_quantiles(self, values, expected, num_streams, weights=None):
    s = self._new_sketch(num_streams)
    self._add_pairs(s, self._value_weight_pairs(values, weights))
    self._assert_quantiles(s, expected)

  @parameterized.named_parameters(*_QUANTILES_TEST_CASES)
  def test_pickle(self, values, expected, num_streams, weights=None):
    s = self._new_sketch(num_streams)
    self._add_pairs(s, self._value_weight_pairs(values, weights))
    pickled = pickle.dumps(s)
    self.assertIsInstance(pickled, bytes)
    unpickled = pickle.loads(pickled)
    self.assertIsInstance(unpickled, sketches.QuantilesSketch)
    self._assert_quantiles(unpickled, expected)

  @parameterized.named_parameters(*_QUANTILES_TEST_CASES)
  def test_merge(self, values, expected, num_streams, weights=None):
    pairs = self._value_weight_pairs(values, weights)
    half = len(pairs) // 2
    s1 = self._new_sketch(num_streams)
    self._add_pairs(s1, pairs[:half])
    s2 = self._new_sketch(num_streams)
    self._add_pairs(s2, pairs[half:])

    # Round-trip both halves through pickle so that merging deserialized
    # sketches is covered as well.
    s1 = _pickle_roundtrip(s1)
    s2 = _pickle_roundtrip(s2)
    s1.Merge(s2)

    self._assert_quantiles(s1, expected)

  @parameterized.named_parameters(*_QUANTILES_TEST_CASES)
  def test_compact(self, values, expected, num_streams, weights=None):
    pairs = self._value_weight_pairs(values, weights)
    half = len(pairs) // 2
    s = self._new_sketch(num_streams)
    self._add_pairs(s, pairs[:half])
    s.Compact()
    self._add_pairs(s, pairs[half:])
    s.Compact()

    self._assert_quantiles(s, expected)

  @parameterized.parameters(*_ACCURACY_TEST_CASES)
  def test_accuracy(self, max_num_elements, eps, num_quantiles):
    values, weights, cdf = self._accuracy_inputs(max_num_elements)
    s = sketches.QuantilesSketch(eps, max_num_elements, 1)
    _add_values(s, values, weights)
    self._assert_single_stream_accuracy(s, num_quantiles, cdf, eps)

  @parameterized.parameters(*_ACCURACY_TEST_CASES)
  def test_accuracy_after_pickle(self, max_num_elements, eps, num_quantiles):
    values, weights, cdf = self._accuracy_inputs(max_num_elements)
    half = max_num_elements // 2
    s = sketches.QuantilesSketch(eps, max_num_elements, 1)
    _add_values(s, values[:half], weights[:half])
    s = _pickle_roundtrip(s)
    _add_values(s, values[half:], weights[half:])
    s = _pickle_roundtrip(s)
    self._assert_single_stream_accuracy(s, num_quantiles, cdf, eps)

  @parameterized.parameters(*_ACCURACY_TEST_CASES)
  def test_accuracy_after_merge(self, max_num_elements, eps, num_quantiles):
    values, weights, cdf = self._accuracy_inputs(max_num_elements)
    s1, s2, s3 = self._three_sketches(eps, max_num_elements, values, weights)
    s2.Merge(s3)
    s1.Merge(s2)
    self._assert_single_stream_accuracy(s1, num_quantiles, cdf, eps)

  @parameterized.parameters(*_ACCURACY_TEST_CASES)
  def test_accuracy_after_compact(self, max_num_elements, eps, num_quantiles):
    values, weights, cdf = self._accuracy_inputs(max_num_elements)
    s1, s2, s3 = self._three_sketches(eps, max_num_elements, values, weights)
    s2.Compact()
    s3.Compact()
    s2.Merge(s3)
    s2.Compact()
    s1.Compact()
    s1.Merge(s2)
    s1.Compact()
    self._assert_single_stream_accuracy(s1, num_quantiles, cdf, eps)


if __name__ == "__main__":
  # Run this module's tests through absl's test runner when executed directly.
  absltest.main()
