from __future__ import with_statement
import random
import re
import socket

import mock
from nose.tools import eq_

from statsd import StatsClient


# Destination every mocked client is expected to sendto().
ADDR = (
    socket.gethostbyname('localhost'),  # localhost resolved to an IPv4 address
    8125,  # port handed to StatsClient by _client()
)


def _client(prefix=None):
    """Return a StatsClient pointed at ADDR with its socket replaced by a mock.

    Nothing is sent over the network; tests read ``client._sock.sendto`` to
    see what the client tried to send.
    """
    host, port = ADDR
    client = StatsClient(host=host, port=port, prefix=prefix)
    client._sock = mock.Mock()
    return client


def _sock_check(cl, count, val):
    """Assert the mocked socket was used *count* times, last with *val*.

    *val* is the expected text payload; it is ASCII-encoded and compared
    together with ADDR against the most recent ``sendto`` call. A falsy
    *val* instead asserts that ``sendto`` has no recorded call at all.
    """
    sendto = cl._sock.sendto
    eq_(sendto.call_count, count)
    if not val:
        eq_(sendto.call_args, None)
        return
    expected = ((val.encode('ascii'), ADDR), {})
    eq_(sendto.call_args, expected)


class assert_raises(object):
    """A context manager that asserts a given exception was raised.

    The block passes when it raises an instance of any of the given
    exception classes (subclasses included, just like ``except``). In that
    case the exception is swallowed. Otherwise ``AssertionError`` is raised.

    >>> with assert_raises(TypeError):
    ...     raise TypeError

    >>> with assert_raises(TypeError):
    ...     raise ValueError
    Traceback (most recent call last):
      ...
    AssertionError: ValueError not in ['TypeError']

    >>> with assert_raises(TypeError):
    ...     pass
    Traceback (most recent call last):
      ...
    AssertionError: No exception raised.

    Or you can specify any of a number of exceptions:

    >>> with assert_raises(TypeError, ValueError):
    ...     raise ValueError

    >>> with assert_raises(TypeError, ValueError):
    ...     raise KeyError
    Traceback (most recent call last):
      ...
    AssertionError: KeyError not in ['TypeError', 'ValueError']

    You can also get the exception back later:

    >>> with assert_raises(TypeError) as cm:
    ...     raise TypeError('bad type!')
    >>> cm.exception.args
    ('bad type!',)
    >>> cm.exc_type is TypeError
    True
    >>> cm.traceback is not None
    True

    The name is lowercase on purpose: the fact that this is a class is an
    implementation detail.

    """

    def __init__(self, *exc_cls):
        # Tuple of acceptable exception classes; passed straight to issubclass().
        self.exc_cls = exc_cls

    def __enter__(self):
        # For access to the exception later.
        return self

    def __exit__(self, typ, value, tb):
        # Raise explicitly instead of using ``assert``. Under ``python -O``
        # asserts are stripped, and this method would then swallow every
        # exception and also pass blocks that raised nothing.
        if typ is None:
            raise AssertionError('No exception raised.')
        if not issubclass(typ, self.exc_cls):
            raise AssertionError('%s not in %s' % (
                typ.__name__, [e.__name__ for e in self.exc_cls]))
        self.exc_type = typ
        self.exception = value
        self.traceback = tb

        # Swallow expected exceptions.
        return True


@mock.patch.object(random, 'random', lambda: -1)
def test_incr():
    # random.random() is pinned to -1, so the rate=0.5 call below is
    # still expected to be sent.
    client = _client()
    steps = [
        (('foo',), {}, 'foo:1|c'),
        (('foo', 10), {}, 'foo:10|c'),
        (('foo', 1.2), {}, 'foo:1.2|c'),
        (('foo', 10), {'rate': 0.5}, 'foo:10|c|@0.5'),
    ]
    for index, (args, kwargs, packet) in enumerate(steps):
        client.incr(*args, **kwargs)
        _sock_check(client, index + 1, packet)


@mock.patch.object(random, 'random', lambda: -1)
def test_decr():
    # Decrements go out as negative counters; random() is pinned so the
    # sampled call is still sent.
    client = _client()
    expected_packets = [
        ('foo:-1|c', ('foo',), {}),
        ('foo:-10|c', ('foo', 10), {}),
        ('foo:-1.2|c', ('foo', 1.2), {}),
        ('foo:-1|c|@0.5', ('foo', 1), {'rate': 0.5}),
    ]
    for sent, (packet, args, kwargs) in enumerate(expected_packets):
        client.decr(*args, **kwargs)
        _sock_check(client, sent + 1, packet)


@mock.patch.object(random, 'random', lambda: -1)
def test_gauge():
    # Absolute gauges, with and without a sample rate.
    client = _client()
    for position, (value, extra, packet) in enumerate((
            (30, {}, 'foo:30|g'),
            (1.2, {}, 'foo:1.2|g'),
            (70, {'rate': 0.5}, 'foo:70|g|@0.5'),
    )):
        client.gauge('foo', value, **extra)
        _sock_check(client, position + 1, packet)


def test_gauge_delta():
    # With delta=True the sign is always written, including '+' for
    # positive changes.
    client = _client()
    expectations = [
        (12, 'foo:+12|g'),
        (-13, 'foo:-13|g'),
        (1.2, 'foo:+1.2|g'),
        (-1.3, 'foo:-1.3|g'),
    ]
    for n, (amount, packet) in enumerate(expectations):
        client.gauge('foo', amount, delta=True)
        _sock_check(client, n + 1, packet)


@mock.patch.object(random, 'random', lambda: -1)
def test_timing():
    # Explicit timings are reported in the 'ms' unit.
    client = _client()
    checks = (
        (100, {}, 'foo:100|ms'),
        (350, {}, 'foo:350|ms'),
        (100, {'rate': 0.5}, 'foo:100|ms|@0.5'),
    )
    for calls_so_far, (millis, extra, packet) in enumerate(checks):
        client.timing('foo', millis, **extra)
        _sock_check(client, calls_so_far + 1, packet)


def test_prepare():
    # Generator test: yields one _prepare() check per case.
    client = _client(None)

    cases = [
        ('foo:1|c', 'foo', '1|c', 1),
        ('bar:50|ms|@0.5', 'bar', '50|ms', 0.5),
        ('baz:23|g', 'baz', '23|g', 1),
    ]

    def _check(expected, stat, value, rate):
        # The random() pin lives inside the check so it is active whenever
        # the runner actually invokes it.
        with mock.patch.object(random, 'random', lambda: -1):
            eq_(expected, client._prepare(stat, value, rate))

    for expected, stat, value, rate in cases:
        yield _check, expected, stat, value, rate


def test_prefix():
    prefixed = _client('foo')
    prefixed.incr('bar')
    # The client prefix is joined to the stat name with a dot.
    _sock_check(prefixed, 1, 'foo.bar:1|c')


def _timer_check(cl, count, start, end):
    """Assert *count* sends, the last one shaped ``<start>:<digits>|<end>``.

    The measured duration isn't known in advance, so only the packet's shape
    is checked. *start* and *end* are matched literally (they may contain
    characters such as '|' and '.' that are special in regular expressions).
    """
    eq_(cl._sock.sendto.call_count, count)
    value = cl._sock.sendto.call_args[0][0].decode('ascii')
    # Escape both ends and the separating pipe. Unescaped, '|' turned the
    # pattern into an alternation that only ever checked the prefix, so the
    # unit/rate suffix (e.g. 'ms|@0.5') was never actually verified.
    exp = re.compile(r'^%s:\d+\|%s$' % (re.escape(start), re.escape(end)))
    assert exp.match(value), '%r does not match %r' % (value, exp.pattern)


def test_timer_manager():
    """StatsClient.timer works as a context manager."""
    client = _client()
    timer = client.timer('foo')
    with timer:
        pass
    _timer_check(client, 1, 'foo', 'ms')


def test_timer_decorator():
    """StatsClient.timer works as a function decorator."""
    client = _client()

    def bar():
        pass

    # Equivalent to decorating bar with @client.timer('bar').
    timed_bar = client.timer('bar')(bar)
    timed_bar()

    _timer_check(client, 1, 'bar', 'ms')


def test_timer_capture():
    """The object bound by ``with ... timer(...) as t`` exposes t.ms."""
    client = _client()
    timer = client.timer('woo')
    with timer as captured:
        # No duration is available while the block is still running.
        eq_(captured.ms, None)
    assert isinstance(captured.ms, int)


@mock.patch.object(random, 'random', lambda: -1)
def test_timer_context_rate():
    client = _client()
    sampled = client.timer('foo', rate=0.5)
    with sampled:
        pass
    # The sample rate follows the unit in the packet.
    _timer_check(client, 1, 'foo', 'ms|@0.5')


@mock.patch.object(random, 'random', lambda: -1)
def test_timer_decorator_rate():
    client = _client()

    def bar():
        pass

    # Decorate and call in one step.
    client.timer('bar', rate=0.1)(bar)()

    _timer_check(client, 1, 'bar', 'ms|@0.1')


def test_timer_context_exceptions():
    """A timed block that raises still reports its timing, and the error escapes."""
    client = _client()

    def _timed_failure():
        with client.timer('foo'):
            raise socket.timeout()

    with assert_raises(socket.timeout):
        _timed_failure()

    _timer_check(client, 1, 'foo', 'ms')


def test_timer_decorator_exceptions():
    """A timed function that raises still reports its timing, and the error escapes."""
    client = _client()

    def foo():
        raise ValueError()

    timed_foo = client.timer('foo')(foo)

    with assert_raises(ValueError):
        timed_foo()

    _timer_check(client, 1, 'foo', 'ms')


def test_pipeline():
    # Stats queued on a pipeline go out as one newline-separated packet.
    client = _client()
    batch = client.pipeline()
    batch.incr('foo')
    batch.decr('bar')
    batch.timing('baz', 320)
    batch.send()
    expected = '\n'.join(['foo:1|c', 'bar:-1|c', 'baz:320|ms'])
    _sock_check(client, 1, expected)


def test_pipeline_manager():
    # Leaving the pipeline's with-block sends everything queued inside it.
    client = _client()
    with client.pipeline() as pipe:
        pipe.incr('foo')
        pipe.decr('bar')
        pipe.gauge('baz', 15)
    expected = '\n'.join(['foo:1|c', 'bar:-1|c', 'baz:15|g'])
    _sock_check(client, 1, expected)


def test_pipeline_timer_manager():
    # A timer obtained from a pipeline: once the pipeline's with-block
    # exits, exactly one timing packet has been sent.
    client = _client()
    with client.pipeline() as pipe:
        timer = pipe.timer('foo')
        with timer:
            pass
    _timer_check(client, 1, 'foo', 'ms')


def test_pipeline_timer_decorator():
    # Same as above, but using the pipeline's timer as a decorator.
    client = _client()
    with client.pipeline() as pipe:
        def foo():
            pass
        pipe.timer('foo')(foo)()
    _timer_check(client, 1, 'foo', 'ms')


def test_pipeline_empty():
    """Queued stats are cleared once the pipeline's with-block has sent them."""
    client = _client()
    with client.pipeline() as pipe:
        pipe.incr('foo')
        queued = len(pipe._stats)
        eq_(1, queued)
    eq_(0, len(pipe._stats))


def test_pipeline_packet_size():
    """Pipelines shouldn't send packets larger than 512 bytes."""
    sc = _client()
    pipe = sc.pipeline()
    # Each queued stat serializes to 'sixteen_char_str:1|c' (20 bytes).
    # Joined by newlines, 32 of them come to 32 * 20 + 31 = 671 bytes. That
    # is more than one 512-byte packet can hold, so the batch must be split;
    # two packets are expected.
    for _ in range(32):
        pipe.incr('sixteen_char_str')
    pipe.send()
    eq_(2, sc._sock.sendto.call_count)
    # Check every packet actually sent, not just hard-coded indices.
    for args, kwargs in sc._sock.sendto.call_args_list:
        packet = args[0]
        assert len(packet) <= 512, '%d-byte packet exceeds 512' % len(packet)
