# Licensed under a 3-clause BSD style license - see LICENSE.rst

from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

from numpy.testing import assert_array_equal

import numpy as np

import pytest

from ..core import ccdmask
from ..ccddata import CCDData


def test_ccdmask_no_ccddata():
    """A plain nested Python list (not a CCDData) is rejected with ValueError."""
    plain_list = [[0] * 3 for _ in range(3)]
    with pytest.raises(ValueError):
        ccdmask(plain_list)


def test_ccdmask_not_2d():
    """ccdmask raises ValueError for CCDData whose data is not 2-dimensional."""
    non_2d_arrays = (
        np.ones(3),          # fewer than 2 dimensions (1-D)
        np.array(10),        # scalar (0-D)
        np.ones((3, 3, 3)),  # more than 2 dimensions (3-D)
    )
    for data in non_2d_arrays:
        # The CCDData is built inside the context, exactly as a single call.
        with pytest.raises(ValueError):
            ccdmask(CCDData(data, unit='adu'))


def test_ccdmask_pixels():
    """Check ccdmask on the ratio of two 16x32 flats with injected defects.

    The two flats start with no bad pixels. Defects are then injected
    progressively: a single bad pixel, a full bad column, and a bad column
    with a gap. After each step the mask returned by ``ccdmask`` is compared
    against an expected boolean mask.
    """
    flat1 = CCDData(np.array([[
        20044, 19829, 19936, 20162, 19948, 19965, 19919, 20004, 19951,
        20002, 19926, 20151, 19886, 20014, 19928, 20025, 19921, 19996,
        19912, 20017, 19969, 20103, 20161, 20110, 19977, 19922, 20004,
        19802, 20079, 19981, 20083, 19871],
       [20068, 20204, 20085, 20027, 20103, 19866, 20089, 19914, 20160,
        19884, 19956, 20095, 20004, 20075, 19899, 20016, 19995, 20178,
        19963, 20030, 20055, 20005, 20073, 19969, 19958, 20040, 19979,
        19938, 19986, 19957, 20172, 20054],
       [20099, 20180, 19912, 20050, 19930, 19930, 20036, 20006, 19833,
        19984, 19879, 19815, 20105, 20011, 19949, 20062, 19837, 20070,
        20047, 19855, 19956, 19928, 19878, 20102, 19940, 20001, 20082,
        20080, 20019, 19991, 19919, 20121],
       [20014, 20262, 19953, 20077, 19928, 20271, 19962, 20048, 20011,
        20054, 20112, 19931, 20125, 19899, 19993, 19939, 19916, 19998,
        19921, 19949, 20246, 20160, 19881, 19863, 19874, 19979, 19989,
        19901, 19850, 19931, 20001, 20167],
       [20131, 19991, 20073, 19945, 19980, 20021, 19938, 19964, 20002,
        20177, 19888, 19901, 19919, 19977, 20280, 20035, 20045, 19849,
        20169, 20074, 20113, 19993, 19965, 20026, 20018, 19966, 20023,
        19965, 19962, 20082, 20027, 20145],
       [20106, 20025, 19846, 19865, 19913, 20046, 19998, 20037, 19986,
        20048, 20005, 19790, 20011, 19985, 19959, 19882, 20085, 19978,
        19881, 19960, 20111, 19936, 19983, 19863, 19819, 19896, 19968,
        20134, 19824, 19990, 20146, 19886],
       [20162, 19997, 19966, 20110, 19822, 19923, 20029, 20129, 19936,
        19882, 20077, 20112, 20040, 20051, 20177, 19763, 20097, 19898,
        19832, 20061, 19919, 20056, 20010, 19929, 20010, 19995, 20124,
        19965, 19922, 19860, 20021, 19989],
       [20088, 20104, 19956, 19959, 20018, 19948, 19836, 20107, 19920,
        20117, 19882, 20039, 20206, 20067, 19784, 20087, 20117, 19990,
        20242, 19861, 19923, 19779, 20024, 20024, 19981, 19915, 20017,
        20053, 19932, 20179, 20062, 19908],
       [19993, 20047, 20008, 20172, 19977, 20054, 19980, 19952, 20138,
        19940, 19995, 20029, 19888, 20191, 19958, 20007, 19938, 19959,
        19933, 20139, 20069, 19905, 20101, 20086, 19904, 19807, 20131,
        20048, 19927, 19905, 19939, 20030],
       [20040, 20051, 19997, 20013, 19942, 20130, 19983, 19603, 19934,
        19944, 19961, 19979, 20164, 19855, 20157, 20010, 20020, 19902,
        20134, 19971, 20228, 19967, 19879, 20022, 19915, 20063, 19768,
        19976, 19860, 20041, 19955, 19984],
       [19807, 20066, 19986, 19999, 19975, 20115, 19998, 20056, 20059,
        20016, 19970, 19964, 20053, 19975, 19985, 19973, 20041, 19918,
        19875, 19997, 19954, 19777, 20117, 20248, 20034, 20019, 20018,
        20058, 20027, 20121, 19909, 20094],
       [19890, 20018, 20032, 20058, 19909, 19906, 19812, 20206, 19908,
        19767, 20127, 20015, 19959, 20026, 20021, 19964, 19824, 19934,
        20147, 19984, 20026, 20168, 19992, 20175, 20040, 20208, 20077,
        19897, 20037, 19996, 19998, 20019],
       [19966, 19897, 20062, 19914, 19780, 20004, 20029, 20140, 20057,
        20134, 20125, 19973, 19894, 19929, 19876, 20135, 19981, 20057,
        20015, 20113, 20107, 20115, 19924, 19987, 19926, 19885, 20013,
        20058, 19950, 20155, 19825, 20092],
       [19889, 20046, 20113, 19991, 19829, 20180, 19949, 20011, 20014,
        20123, 19980, 19770, 20086, 20041, 19957, 19949, 20026, 19918,
        19777, 20062, 19862, 20085, 20090, 20122, 19692, 19937, 19897,
        20018, 19935, 20037, 19946, 19998],
       [20001, 19940, 19994, 19835, 19959, 19895, 20017, 20002, 20007,
        19851, 19900, 20044, 20354, 19814, 19869, 20148, 20001, 20143,
        19778, 20146, 19975, 19859, 20008, 20041, 19937, 20072, 20203,
        19778, 20027, 20075, 19877, 19999],
       [19753, 19866, 20037, 20149, 20020, 20071, 19955, 20164, 19837,
        19967, 19959, 20163, 20003, 20127, 20065, 20118, 20104, 19839,
        20124, 20057, 19943, 20023, 20138, 19996, 19910, 20048, 20070,
        19833, 19913, 20012, 19897, 19983]]), unit='adu')
    flat2 = CCDData(np.array([[
        20129, 20027, 19945, 20085, 19951, 20015, 20102, 19957, 20100,
        19865, 19878, 20111, 20047, 19882, 19929, 20079, 19937, 19999,
        20109, 19929, 19985, 19970, 19941, 19868, 20191, 20142, 19948,
        20079, 19975, 19949, 19972, 20053],
       [20075, 19980, 20035, 20014, 19865, 20058, 20091, 20030, 19931,
        19806, 19990, 19902, 19895, 19789, 20079, 20048, 20040, 19968,
        20049, 19946, 19982, 19865, 19766, 19903, 20025, 19916, 19904,
        20128, 19865, 20103, 19864, 19832],
       [20008, 19989, 20032, 19891, 20063, 20061, 20179, 19920, 19960,
        19655, 19897, 19943, 20015, 20123, 20009, 19940, 19876, 19964,
        20097, 19814, 20086, 20096, 20030, 20140, 19903, 19858, 19978,
        19817, 20107, 19893, 19988, 19956],
       [20105, 19873, 20003, 19671, 19993, 19981, 20234, 19976, 20079,
        19882, 19982, 19959, 19882, 20103, 20008, 19960, 20084, 20025,
        19864, 19969, 19945, 19979, 19937, 19965, 19981, 19957, 19906,
        19959, 19839, 19679, 19988, 20154],
       [20053, 20152, 19858, 20134, 19867, 20027, 20024, 19884, 20015,
        19904, 19992, 20137, 19981, 20147, 19814, 20035, 19992, 19921,
        20007, 20103, 19920, 19889, 20182, 19964, 19859, 20016, 20011,
        20203, 19761, 19954, 20151, 19973],
       [20029, 19863, 20217, 19819, 19984, 19950, 19914, 20028, 19980,
        20033, 20016, 19796, 19901, 20027, 20078, 20136, 19995, 19915,
        20014, 19920, 19996, 20216, 19939, 19967, 19949, 20023, 20024,
        19949, 19949, 19902, 19980, 19895],
       [19962, 19872, 19926, 20047, 20136, 19944, 20151, 19956, 19958,
        20054, 19942, 20010, 19972, 19936, 20062, 20259, 20230, 19927,
        20004, 19963, 20095, 19866, 19942, 19958, 20149, 19956, 20000,
        19979, 19949, 19892, 20249, 20050],
       [20019, 19999, 19954, 20095, 20045, 20002, 19761, 20187, 20113,
        20048, 20117, 20002, 19938, 19968, 19993, 19995, 20094, 19913,
        19963, 19813, 20040, 19950, 19992, 19958, 20043, 19925, 20036,
        19930, 20057, 20055, 20040, 19937],
       [19958, 19984, 19842, 19990, 19985, 19958, 20070, 19850, 20026,
        20047, 20081, 20094, 20048, 20048, 19917, 19893, 19766, 19765,
        20109, 20067, 19905, 19870, 19832, 20019, 19868, 20075, 20132,
        19916, 19944, 19840, 20140, 20117],
       [19995, 20122, 19998, 20039, 20125, 19879, 19911, 20010, 19944,
        19994, 19903, 20057, 20021, 20139, 19972, 20026, 19922, 20132,
        19976, 20025, 19948, 20038, 19807, 19809, 20145, 20003, 20090,
        19848, 19884, 19936, 19997, 19944],
       [19839, 19990, 20005, 19826, 20070, 19987, 20015, 19835, 20083,
        19908, 19910, 20218, 19960, 19937, 19987, 19808, 19893, 19929,
        20004, 20055, 19973, 19794, 20242, 20082, 20110, 20058, 19876,
        20042, 20064, 19966, 20041, 20015],
       [20048, 20203, 19855, 20011, 19888, 19926, 19973, 19893, 19986,
        20152, 20030, 19880, 20012, 19848, 19959, 20002, 20027, 19935,
        19975, 19905, 19932, 20190, 20188, 19903, 20012, 19943, 19954,
        19891, 19947, 19939, 19974, 19808],
       [20102, 20041, 20013, 20097, 20101, 19859, 20011, 20144, 19920,
        19880, 20134, 19963, 19980, 20090, 20027, 19822, 20051, 19903,
        19784, 19845, 20014, 19974, 20043, 20141, 19968, 20055, 20066,
        20045, 20182, 20104, 20008, 19999],
       [19932, 20023, 20042, 19894, 20070, 20015, 20172, 20024, 19988,
        20181, 20180, 20023, 19978, 19989, 19976, 19870, 20152, 20003,
        19984, 19903, 19904, 19940, 19990, 19922, 19911, 19976, 19841,
        19946, 20273, 20085, 20142, 20122],
       [19959, 20071, 20020, 20037, 20024, 19967, 20044, 20009, 19997,
        20045, 19995, 19831, 20035, 19976, 20049, 19958, 20021, 19887,
        19961, 19928, 19805, 20173, 19928, 19939, 19826, 20096, 20078,
        20100, 19935, 19942, 19969, 19941],
       [19876, 20056, 20071, 19886, 19979, 20174, 19978, 20037, 19933,
        20184, 19948, 20034, 19896, 19905, 20138, 19870, 19936, 20085,
        19971, 20063, 19936, 19941, 19928, 19937, 19970, 19931, 20036,
        19965, 19855, 19949, 19965, 19821]]), unit='adu')

    # Expected mask: initially nothing is flagged; entries are switched on
    # below as defects are injected. Use the builtin ``bool`` -- the
    # ``np.bool`` alias was deprecated in NumPy 1.20 and removed in 1.24.
    target_mask = np.zeros(flat1.shape, dtype=bool)

    # No bad pixels in this scenario
    ratio = flat1.divide(flat2)
    mask = ccdmask(ratio, ncsig=9, nlsig=11)
    assert mask.shape == ratio.shape
    assert_array_equal(mask, target_mask)

    # Check again with different ncsig and nlsig
    ratio = flat1.divide(flat2)
    mask = ccdmask(ratio, ncsig=11, nlsig=15)
    assert mask.shape == ratio.shape
    assert_array_equal(mask, target_mask)

    # Add single bad pixel: an extreme ratio (65535 / 1) at row 14, col 3
    flat1.data[14, 3] = 65535
    flat2.data[14, 3] = 1
    ratio = flat1.divide(flat2)
    mask = ccdmask(ratio, ncsig=11, nlsig=15)
    target_mask[14, 3] = True
    assert_array_equal(mask, target_mask)

    # Add single bad column (column 7, every row)
    flat1.data[:, 7] = 65535
    flat2.data[:, 7] = 1
    ratio = flat1.divide(flat2)
    target_mask[:, 7] = True

    # The full bad column must be found under every combination of the
    # byblocks / findbadcolumns options.
    mask = ccdmask(ratio, ncsig=11, nlsig=15)
    assert_array_equal(mask, target_mask)

    mask = ccdmask(ratio, ncsig=11, nlsig=15, byblocks=True)
    assert_array_equal(mask, target_mask)

    mask = ccdmask(ratio, ncsig=11, nlsig=15, findbadcolumns=True)
    assert_array_equal(mask, target_mask)

    mask = ccdmask(ratio, ncsig=11, nlsig=15, findbadcolumns=True,
                   byblocks=True)
    assert_array_equal(mask, target_mask)

    # Add bad column with gaps: column 2 is bad in rows 0-7 and 11-end,
    # leaving rows 8-10 untouched.
    flat1.data[0:8, 2] = 65535
    flat1.data[11:, 2] = 65535
    flat2.data[0:8, 2] = 1
    flat2.data[11:, 2] = 1
    ratio = flat1.divide(flat2)

    # Without bad-column finding only the injected pixels are flagged.
    mask = ccdmask(ratio, ncsig=11, nlsig=15, findbadcolumns=False)
    target_mask[0:8, 2] = True
    target_mask[11:, 2] = True
    assert_array_equal(mask, target_mask)

    # With bad-column finding the gap rows are expected to be flagged too,
    # so the whole of column 2 is masked.
    mask = ccdmask(ratio, ncsig=11, nlsig=15, findbadcolumns=True)
    target_mask[:, 2] = True
    assert_array_equal(mask, target_mask)
