from re import match
from os.path import join
from tempfile import gettempdir

from pygame import Surface
from pygame.font import Font
from pygame.draw import aaline
from pygame.locals import *

from GameChild import GameChild
from Sprite import Sprite
from Animation import Animation

class Interpolator(list, GameChild):
    """A list of named Nodeset objects built from the "interpolate" config section.

    When check_command_line("-interpolator") is true, a GUI editor for the
    nodesets is created as well and stored in ``self.gui``.
    """

    def __init__(self, parent):
        GameChild.__init__(self, parent)
        self.set_nodesets()
        self.gui_enabled = self.check_command_line("-interpolator")
        if self.gui_enabled:
            self.gui = GUI(self)

    def set_nodesets(self):
        """Add one Nodeset per option of the "interpolate" section, if present."""
        config = self.get_configuration()
        if not config.has_section("interpolate"):
            return
        for name, value in config.get_section("interpolate").iteritems():
            self.add_nodeset(name, value)

    def add_nodeset(self, name, value, method=None):
        """Append a new Nodeset built from ``value`` and return its index."""
        self.append(Nodeset(name, value, method))
        return len(self) - 1

    def is_gui_active(self):
        """Return whether the editor GUI exists and is currently active."""
        return self.gui_enabled and self.gui.active

    def get_nodeset(self, name):
        """Return the first nodeset called ``name``, or None when there is none."""
        return next((candidate for candidate in self if candidate.name == name),
                    None)

    def remove(self, outgoing):
        """Remove the first nodeset whose name equals ``outgoing.name``.

        Unlike list.remove, matching is by name rather than equality, and a
        missing name is silently ignored.
        """
        position = next((index for index, candidate in enumerate(self)
                         if candidate.name == outgoing.name), None)
        if position is not None:
            self.pop(position)


class Nodeset(list):
    """An x-sorted list of Node control points and the splines fitted through them.

    ``splines`` holds one spline per pair of consecutive nodes and is rebuilt
    by set_splines(); ``interpolation_method`` is LINEAR or (for any other
    value, including None) a natural cubic spline.
    """

    LINEAR, CUBIC = range(2)

    def __init__(self, name, nodes, method=None):
        """Create a nodeset called ``name``.

        ``nodes`` is either a raw string in the format accepted by
        parse_raw() (which then also decides the method), or an iterable of
        coordinate sequences, in which case ``method`` is used.
        """
        list.__init__(self, [])
        self.name = name
        if isinstance(nodes, str):
            self.parse_raw(nodes)
        else:
            self.interpolation_method = method
            # Splines are built once below rather than after every node.
            self.parse_list(nodes, False)
        self.set_splines()

    def parse_raw(self, raw):
        """Load nodes from a string such as "L 0 0, 1 1" or "C 0 0, .5 2, 1 1".

        A leading "L"/"l" selects linear interpolation, any other first
        character selects cubic; the rest is comma-separated "x y" pairs.
        Splines are NOT rebuilt here.
        """
        raw = raw.strip()
        if raw[0].upper() == "L":
            self.set_interpolation_method(self.LINEAR, False)
        else:
            self.set_interpolation_method(self.CUBIC, False)
        for node in raw[1:].strip().split(","):
            # A list (not a lazy map) so add_node can index it on Python 3.
            self.add_node([float(value) for value in node.strip().split()],
                          False)

    def set_interpolation_method(self, method, refresh=True):
        """Select LINEAR or CUBIC and optionally rebuild the splines."""
        self.interpolation_method = method
        if refresh:
            self.set_splines()

    def add_node(self, coordinates, refresh=True):
        """Insert a Node at its sorted x position.

        Returns the new node's index, or None (and inserts nothing) when a
        node with the same x already exists.
        """
        x = coordinates[0]
        for ii, node in enumerate(self):
            if x < node.x:
                self.insert(ii, Node(coordinates))
                index = ii
                break
            elif x == node.x:
                return None
        else:
            self.append(Node(coordinates))
            index = len(self) - 1
        if refresh:
            self.set_splines()
        return index

    def parse_list(self, nodes, refresh=True):
        """Add every coordinate sequence in ``nodes``.

        The splines are rebuilt once at the end (when ``refresh``) instead of
        after each insertion.
        """
        for node in nodes:
            self.add_node(node, False)
        if refresh:
            self.set_splines()

    def set_splines(self):
        """Rebuild ``splines`` for the current interpolation method."""
        if self.interpolation_method == self.LINEAR:
            self.set_linear_splines()
        else:
            self.set_cubic_splines()

    def set_linear_splines(self):
        """Fit one straight segment between each pair of adjacent nodes."""
        self.splines = [
            LinearSpline(left.x, left.y,
                         float(right.y - left.y) / (right.x - left.x))
            for left, right in zip(self, self[1:])]

    def set_cubic_splines(self):
        """Fit a natural cubic spline (zero second derivative at both ends).

        Solves the tridiagonal system for the quadratic coefficients ``c``
        by forward elimination and back substitution, then derives ``b`` and
        ``d``; segment ii is a + b*dx + c*dx**2 + d*dx**3 with dx = t - x[ii].
        """
        n = len(self) - 1
        if n < 1:
            # Fewer than two nodes: there is no segment to fit.
            self.splines = []
            return
        # Work in floats so integer coordinates cannot trigger Python 2 floor
        # division in the ratios below.
        x = [float(node.x) for node in self]
        a = [float(node.y) for node in self]
        h = [x[ii + 1] - x[ii] for ii in range(n)]
        alpha = [None] + [(3.0 / h[ii]) * (a[ii + 1] - a[ii]) -
                          (3.0 / h[ii - 1]) * (a[ii] - a[ii - 1])
                          for ii in range(1, n)]
        ell = [1.0] + [None] * n
        mu = [0.0] + [None] * n
        z = [0.0] + [None] * n
        for ii in range(1, n):
            ell[ii] = 2 * (x[ii + 1] - x[ii - 1]) - h[ii - 1] * mu[ii - 1]
            mu[ii] = h[ii] / ell[ii]
            z[ii] = (alpha[ii] - h[ii - 1] * z[ii - 1]) / ell[ii]
        b = [None] * n
        c = [None] * n + [0.0]  # natural boundary: c[n] == 0
        d = [None] * n
        for jj in range(n - 1, -1, -1):
            c[jj] = z[jj] - mu[jj] * c[jj + 1]
            b[jj] = (a[jj + 1] - a[jj]) / h[jj] - \
                    (h[jj] * (c[jj + 1] + 2 * c[jj])) / 3
            d[jj] = (c[jj + 1] - c[jj]) / (3 * h[jj])
        self.splines = [CubicSpline(x[ii], a[ii], b[ii], c[ii], d[ii])
                        for ii in range(n)]

    def get_y(self, t, loop=False, reverse=False, natural=False):
        """Evaluate the curve at ``t``.

        loop    -- wrap ``t`` into [first x, last x).
        reverse -- ping-pong: every other period runs backwards.
        natural -- (without loop/reverse) extrapolate past the end nodes
                   instead of clamping ``t`` to them.
        """
        start, end = self[0].x, self[-1].x
        if loop or reverse:
            # Measure from the first node so nodesets that do not begin at
            # x == 0 (see resize) still wrap within their own domain.
            length = end - start
            offset = t - start
            phase = offset % length
            # Floor division keeps the period count exact for non-integer
            # lengths; mirroring the phase (rather than t) keeps the value
            # continuous at the turning point.
            if reverse and (offset // length) % 2:
                phase = length - phase
            t = start + phase
        elif not natural:
            t = min(max(t, start), end)
        splines = self.splines
        for spline, following in zip(splines, splines[1:]):
            if t < following.x:
                return spline.get_y(t)
        return splines[-1].get_y(t)

    def remove(self, node, refresh=True):
        """Remove the first node equal to ``node`` and optionally refit."""
        list.remove(self, node)
        if refresh:
            self.set_splines()

    def resize(self, left, length, refresh=True):
        """Linearly rescale node x positions to span [left, left + length]."""
        old_left = self[0].x
        old_length = float(self.get_length())
        for node in self:
            node.x = left + length * (node.x - old_left) / old_length
        if refresh:
            self.set_splines()

    def get_length(self):
        """Return the distance between the first and last node x positions."""
        return self[-1].x - self[0].x


class Node(list):
    """A mutable coordinate list whose first two items are also ``x`` and ``y``."""

    def __init__(self, coordinates):
        list.__init__(self, coordinates)

    def __getattr__(self, name):
        # Only reached when normal lookup fails, so x/y map onto the items.
        if name == "x":
            return self[0]
        elif name == "y":
            return self[1]
        # list has no __get__; raise a proper AttributeError naming the
        # attribute so hasattr/getattr(default) behave and errors are clear.
        raise AttributeError("%r object has no attribute %r"
                             % (type(self).__name__, name))

    def __setattr__(self, name, value):
        # Writing x/y updates the list items; other names become ordinary
        # instance attributes.
        if name == "x":
            list.__setitem__(self, 0, value)
        elif name == "y":
            list.__setitem__(self, 1, value)
        else:
            list.__setattr__(self, name, value)


class Spline:
    """Base class for one curve segment; ``x`` is where the segment begins."""

    def __init__(self, x):
        # Segment start, used for lookup and as the polynomial origin.
        self.x = x


class CubicSpline(Spline):
    """Cubic segment y = a + b*dx + c*dx**2 + d*dx**3 where dx = t - x."""

    def __init__(self, x, a, b, c, d):
        Spline.__init__(self, x)
        self.a, self.b, self.c, self.d = a, b, c, d

    def get_y(self, t):
        """Evaluate the cubic polynomial at ``t``."""
        dx = t - self.x
        return self.a + self.b * dx + self.c * dx ** 2 + self.d * dx ** 3


class LinearSpline(Spline):
    """Straight segment through (x, y) with slope m."""

    def __init__(self, x, y, m):
        Spline.__init__(self, x)
        self.y, self.m = y, m

    def get_y(self, t):
        """Evaluate the line at ``t``."""
        run = t - self.x
        return self.m * run + self.y


class GUI(Animation):
    """Interactive editor for the nodesets held by an Interpolator (``parent``).

    While active it plots the selected nodeset inside ``plot_rect`` and
    handles input:

    * left click in empty plot space adds a node; right click on a marker
      removes that node;
    * left/right click on the nodeset name cycles forward/backward;
    * left click on an x-axis label opens a prompt for that endpoint's "x y";
    * the bottom button row duplicates, writes the configuration, deletes,
      switches interpolation method, or cycles the split mode.

    Edits are copied into the "interpolate" configuration section by
    store_in_configuration().
    """

    B_DUPLICATE, B_WRITE, B_DELETE, B_LINEAR, B_CUBIC, B_SPLIT = range(6)
    S_NONE, S_LEFT, S_RIGHT = range(3)
    # Caption of the split button for each split mode, indexed by S_*.
    SPLIT_LABELS = ("Split: No", "Split: L", "Split: R")

    def __init__(self, parent):
        Animation.__init__(self, parent, unfiltered=True)
        self.audio = self.get_audio()
        self.display = self.get_game().display
        self.display_surface = self.get_display_surface()
        self.time_filter = self.get_game().time_filter
        self.delegate = self.get_delegate()
        # Must be set before set_buttons(), which captions the split button.
        self.split = self.S_NONE
        self.success_indicator_active = True
        self.success_indicator_blink_count = 0
        self.load_configuration()
        # Buttons render their captions with this font.
        self.font = Font(None, self.label_size)
        self.prompt = Prompt(self)
        self.set_temporary_file()
        self.set_background()
        self.set_success_indicator()
        self.set_plot_rect()
        self.set_marker_frame()
        self.set_buttons()
        self.active = False
        self.set_nodeset_index()
        self.set_y_range()
        self.set_markers()
        self.subscribe(self.respond_to_command)
        self.subscribe(self.respond_to_mouse_down, MOUSEBUTTONDOWN)
        self.subscribe(self.respond_to_key, KEYDOWN)
        self.register(self.show_success_indicator, interval=100)
        self.register(self.save_temporary_file, interval=10000)
        self.play(self.save_temporary_file)

    def load_configuration(self):
        """Read display settings from the "interpolator-gui" section."""
        config = self.get_configuration("interpolator-gui")
        self.label_size = config["label-size"]
        self.axis_label_count = config["axis-label-count"]
        self.margin = config["margin"]
        self.curve_color = config["curve-color"]
        self.marker_size = config["marker-size"]
        self.marker_color = config["marker-color"]
        self.label_precision = config["label-precision"]
        self.template_nodeset = config["template-nodeset"]
        self.template_nodeset_name = config["template-nodeset-name"]
        self.flat_y_range = config["flat-y-range"]

    def set_temporary_file(self):
        """Open (truncating) the backup file "pgfw-config" in the temp dir.

        The handle stays open for the lifetime of the GUI and is rewritten
        by save_temporary_file().
        """
        self.temporary_file = open(join(gettempdir(), "pgfw-config"), "w")

    def set_background(self):
        """Build a black surface the size of the display."""
        surface = Surface(self.display_surface.get_size())
        surface.fill((0, 0, 0))
        self.background = surface

    def set_success_indicator(self):
        """Build the 10x10 green square blinked after a successful write."""
        surface = Surface((10, 10))
        surface.fill((0, 255, 0))
        rect = surface.get_rect()
        rect.topleft = self.display_surface.get_rect().topleft
        self.success_indicator, self.success_indicator_rect = surface, rect

    def set_plot_rect(self):
        """Shrink the display rect by ``margin`` to get the plotting area."""
        margin = self.margin
        self.plot_rect = self.display_surface.get_rect().inflate(-margin,
                                                                 -margin)

    def set_marker_frame(self):
        """Draw the shared X-shaped marker image on a color-keyed surface."""
        size = self.marker_size
        surface = Surface((size, size))
        transparent_color = (255, 0, 255)
        surface.fill(transparent_color)
        surface.set_colorkey(transparent_color)
        line_color = self.marker_color
        aaline(surface, line_color, (0, 0), (size - 1, size - 1))
        aaline(surface, line_color, (0, size - 1), (size - 1, 0))
        self.marker_frame = surface

    def set_buttons(self):
        """(Re)create the button row along the bottom-left of the display.

        The split button is captioned from the current split mode so that
        rebuilding the buttons (e.g. in rearrange) does not lose it.
        """
        self.buttons = buttons = []
        captions = ("Duplicate", "Write", "Delete", "Linear", "Cubic",
                    self.SPLIT_LABELS[self.split])
        x = 0
        for caption in captions:
            buttons.append(Button(self, caption, x))
            x += buttons[-1].location.w + 10

    def set_nodeset_index(self, increment=None, index=None):
        """Select a nodeset and refresh its name label.

        With ``index`` given, select it directly.  Otherwise a falsy
        ``increment`` selects index 0, and a non-zero one steps from the
        current index, wrapping around at both ends.
        """
        parent = self.parent
        if index is None:
            if not increment:
                index = 0
            else:
                index = self.nodeset_index + increment
                limit = len(parent) - 1
                if index > limit:
                    index = 0
                elif index < 0:
                    index = limit
        self.nodeset_index = index
        self.set_nodeset_label()

    def set_nodeset_label(self):
        """Render the current nodeset's name at the display's bottom-right."""
        surface = self.font.render(self.get_nodeset().name, True, (0, 0, 0),
                                   (255, 255, 255))
        rect = surface.get_rect()
        rect.bottomright = self.display_surface.get_rect().bottomright
        self.nodeset_label, self.nodeset_label_rect = surface, rect

    def get_nodeset(self):
        """Return the selected nodeset, creating the template one if none exist."""
        if not len(self.parent):
            self.parent.add_nodeset(self.template_nodeset_name,
                                    self.template_nodeset)
            self.set_nodeset_index(0)
        return self.parent[self.nodeset_index]

    def set_y_range(self):
        """Compute ``y_range`` = [min, max] of the plotted curve, then labels.

        The curve is sampled at 1% steps of the plot width.  A flat curve
        gets ``flat_y_range`` of headroom; split mode re-centers the range.
        """
        width = self.plot_rect.w
        nodeset = self.get_nodeset()
        first, last = nodeset[0].y, nodeset[-1].y
        # Seed with the ordered endpoints: a decreasing curve would otherwise
        # start from an inverted range and leave the end node off the plot.
        self.y_range = y_range = [min(first, last), max(first, last)]
        x = 0
        while x < width:
            y = nodeset.get_y(self.get_function_coordinates(x)[0])
            if y < y_range[0]:
                y_range[0] = y
            elif y > y_range[1]:
                y_range[1] = y
            x += width * .01
        if y_range[1] - y_range[0] == 0:
            y_range[1] += self.flat_y_range
        if self.split:
            self.adjust_for_split(y_range, nodeset)
        self.set_axis_labels()

    def get_function_coordinates(self, xp=0, yp=0):
        """Map an offset inside ``plot_rect`` (pixels) to curve coordinates."""
        nodeset = self.get_nodeset()
        x_min, x_max, (y_min, y_max) = nodeset[0].x, nodeset[-1].x, self.y_range
        rect = self.plot_rect
        x = float(xp) / (rect.right - rect.left) * (x_max - x_min) + x_min
        y = float(yp) / (rect.bottom - rect.top) * (y_min - y_max) + y_max
        return x, y

    def adjust_for_split(self, y_range, nodeset):
        """Widen ``y_range`` in place so the first (S_LEFT) or last (S_RIGHT)
        node's y sits in the vertical middle of the plot."""
        middle = nodeset[0].y if self.split == self.S_LEFT else nodeset[-1].y
        below, above = middle - y_range[0], y_range[1] - middle
        if below > above:
            y_range[1] += below - above
        else:
            y_range[0] -= above - below

    def set_axis_labels(self):
        """Render min/max labels for both axes.

        ``axis_labels[0]`` holds the first node's x label and y_range[0];
        ``axis_labels[1]`` the last node's x label and y_range[1].  Each
        entry is ((x_surface, x_rect), (y_surface, y_rect)).
        """
        self.axis_labels = labels = []
        nodeset, formatted, render, rect, yr = (self.get_nodeset(),
                                                self.get_formatted_measure,
                                                self.font.render,
                                                self.plot_rect, self.y_range)
        for ii, node in enumerate(nodeset[0::len(nodeset) - 1]):
            xs = render(formatted(node.x), True, (0, 0, 0), (255, 255, 255))
            xsr = xs.get_rect()
            xsr.top = rect.bottom
            if not ii:
                xsr.left = rect.left
            else:
                xsr.right = rect.right
            ys = render(formatted(yr[ii]), True, (0, 0, 0), (255, 255, 255))
            ysr = ys.get_rect()
            ysr.right = rect.left
            if not ii:
                ysr.bottom = rect.bottom
            else:
                ysr.top = rect.top
            labels.append(((xs, xsr), (ys, ysr)))

    def get_formatted_measure(self, measure):
        """Format ``measure`` to ``label_precision`` significant digits."""
        return "%s" % float(("%." + str(self.label_precision) + "g") % measure)

    def deactivate(self):
        """Leave the editor, restoring the time filter, mute and mouse states."""
        self.active = False
        self.time_filter.open()
        self.audio.muted = self.saved_mute_state
        self.display.set_mouse_visibility(self.saved_mouse_state)

    def respond_to_command(self, event):
        """Toggle on "toggle-interpolator"; when active, handle reset/quit."""
        compare = self.delegate.compare
        if compare(event, "toggle-interpolator"):
            self.toggle()
        elif self.active:
            if compare(event, "reset-game"):
                self.deactivate()
            elif compare(event, "quit"):
                self.get_game().end(event)

    def toggle(self):
        """Activate, or deactivate and post "refresh-nodesets"."""
        if self.active:
            self.deactivate()
            self.get_game().delegate.post("refresh-nodesets")
        else:
            self.activate()

    def activate(self):
        """Enter the editor: close the time filter, mute audio, show mouse."""
        self.active = True
        self.time_filter.close()
        self.saved_mute_state = self.audio.muted
        self.audio.mute()
        self.draw()
        # NOTE(review): assumes set_mouse_visibility returns the prior state.
        self.saved_mouse_state = self.display.set_mouse_visibility(True)

    def respond_to_mouse_down(self, event):
        """Handle mouse clicks while the editor is active.

        Button 1 acts on the nodeset label, x-axis labels, buttons, or empty
        plot space; button 3 cycles nodesets backward or deletes the clicked
        marker's node.  While the prompt is open, a click outside it closes
        the prompt.
        """
        redraw = False
        if self.active and not self.prompt.active:
            nodeset_rect = self.nodeset_label_rect
            plot_rect = self.plot_rect
            if event.button == 1:
                pos = event.pos
                if nodeset_rect.collidepoint(pos):
                    self.set_nodeset_index(1)
                    redraw = True
                elif self.axis_labels[0][0][1].collidepoint(pos):
                    text = "{0} {1}".format(*map(self.get_formatted_measure,
                                                 self.get_nodeset()[0]))
                    self.prompt.activate(text, self.resize_nodeset, 0)
                elif self.axis_labels[1][0][1].collidepoint(pos):
                    text = "{0} {1}".format(*map(self.get_formatted_measure,
                                                 self.get_nodeset()[-1]))
                    self.prompt.activate(text, self.resize_nodeset, -1)
                else:
                    bi = self.collide_buttons(pos)
                    if bi is not None:
                        if bi == self.B_WRITE:
                            self.get_configuration().write()
                            self.play(self.show_success_indicator)
                        elif bi in (self.B_LINEAR, self.B_CUBIC):
                            nodeset = self.get_nodeset()
                            if bi == self.B_LINEAR:
                                nodeset.set_interpolation_method(Nodeset.LINEAR)
                            else:
                                nodeset.set_interpolation_method(Nodeset.CUBIC)
                            self.store_in_configuration()
                            redraw = True
                        elif bi == self.B_DUPLICATE:
                            self.prompt.activate("", self.add_nodeset)
                        elif bi == self.B_DELETE and len(self.parent) > 1:
                            self.parent.remove(self.get_nodeset())
                            # The next nodeset slides into the removed slot,
                            # so stay on this index (wrapping to 0 when the
                            # last one was removed) instead of skipping it.
                            index = self.nodeset_index
                            if index >= len(self.parent):
                                index = 0
                            self.set_nodeset_index(index=index)
                            self.store_in_configuration()
                            redraw = True
                        elif bi == self.B_SPLIT:
                            self.toggle_split()
                            redraw = True
                    elif plot_rect.collidepoint(pos) and \
                             self.collide_markers(pos) is None:
                        xp, yp = pos[0] - plot_rect.left, pos[1] - plot_rect.top
                        self.get_nodeset().add_node(
                            self.get_function_coordinates(xp, yp))
                        self.store_in_configuration()
                        redraw = True
            elif event.button == 3:
                pos = event.pos
                if nodeset_rect.collidepoint(pos):
                    self.set_nodeset_index(-1)
                    redraw = True
                elif plot_rect.collidepoint(pos):
                    marker = self.collide_markers(pos)
                    if marker is not None:
                        self.get_nodeset().remove(marker.node)
                        self.store_in_configuration()
                        redraw = True
        elif self.active and self.prompt.active and \
                 not self.prompt.location.collidepoint(event.pos):
            # Sprites are positioned through ``location`` everywhere else in
            # this module (buttons, markers, Prompt.place).
            self.prompt.deactivate()
            redraw = True
        if redraw:
            self.set_y_range()
            self.set_markers()
            self.draw()

    def resize_nodeset(self, text, index):
        """Prompt callback: move endpoint ``index`` (0 or -1) to "x y" in ``text``.

        Interior nodes are rescaled to the new x span.  Returns True on
        success (which closes the prompt) and False when ``text`` does not
        parse or the new x would not keep the first node left of the last.
        """
        result = match(r"^\s*(-{,1}\d*\.{,1}\d*)\s+(-{,1}\d*\.{,1}\d*)\s*$",
                       text)
        if not result:
            return False
        try:
            x, y = [float(value) for value in result.group(1, 2)]
        except ValueError:
            # The pattern also admits "", "-", "." and "-.".
            return False
        nodeset = self.get_nodeset()
        if not ((index == -1 and x > nodeset[0].x) or
                (index == 0 and x < nodeset[-1].x)):
            return False
        nodeset[index].y = y
        if index == -1:
            nodeset.resize(nodeset[0].x, x - nodeset[0].x)
        else:
            nodeset.resize(x, nodeset[-1].x - x)
        self.store_in_configuration()
        # set_y_range() also rebuilds the axis labels.
        self.set_y_range()
        self.set_markers()
        self.draw()
        return True

    def collide_buttons(self, pos):
        """Return the index of the button under ``pos``, or None."""
        for ii, button in enumerate(self.buttons):
            if button.location.collidepoint(pos):
                return ii
        return None

    def store_in_configuration(self):
        """Rewrite the "interpolate" section from every nodeset.

        Each value is "L" or "C" followed by comma-separated " x y" pairs,
        the format Nodeset.parse_raw reads back.
        """
        config = self.get_configuration()
        section = "interpolate"
        config.clear_section(section)
        for nodeset in self.parent:
            if nodeset.interpolation_method == Nodeset.LINEAR:
                code = "L"
            else:
                code = "C"
            code += ",".join(
                " {0} {1}".format(*map(self.get_formatted_measure, node))
                for node in nodeset)
            if not config.has_section(section):
                config.add_section(section)
            config.set(section, nodeset.name, code)

    def toggle_split(self):
        """Cycle the split mode S_NONE -> S_LEFT -> S_RIGHT and re-caption."""
        self.split += 1
        if self.split > self.S_RIGHT:
            self.split = self.S_NONE
        self.buttons[self.B_SPLIT].set_frame(self.SPLIT_LABELS[self.split])

    def add_nodeset(self, name):
        """Prompt callback: duplicate the current nodeset under ``name``.

        Returns False (keeping the prompt open) for a blank name or a name
        already in use, since nodesets are stored and removed by name.
        """
        name = name.strip()
        if not name or self.parent.get_nodeset(name) is not None:
            return False
        nodeset = self.get_nodeset()
        self.set_nodeset_index(index=self.parent.add_nodeset(
            name, nodeset, nodeset.interpolation_method))
        # Markers must reference the copy's nodes, not the original's.
        self.set_markers()
        self.store_in_configuration()
        self.draw()
        return True

    def collide_markers(self, pos):
        """Return the marker under ``pos``, or None."""
        for marker in self.markers:
            if marker.location.collidepoint(pos):
                return marker
        return None

    def set_markers(self):
        """Create a marker centered on every interior node of the nodeset."""
        self.markers = markers = []
        for node in self.get_nodeset()[1:-1]:
            markers.append(Marker(self, node))
            markers[-1].location.center = self.get_plot_coordinates(*node)

    def get_plot_coordinates(self, x=0, y=0):
        """Map curve coordinates to absolute screen pixels within plot_rect."""
        nodeset = self.get_nodeset()
        x_min, x_max, (y_min, y_max) = nodeset[0].x, nodeset[-1].x, self.y_range
        x_ratio = float(x - x_min) / (x_max - x_min)
        rect = self.plot_rect
        xp = x_ratio * (rect.right - rect.left) + rect.left
        y_ratio = float(y - y_min) / (y_max - y_min)
        yp = rect.bottom - y_ratio * (rect.bottom - rect.top)
        return xp, yp

    def draw(self):
        """Redraw the whole editor screen."""
        display_surface = self.display_surface
        display_surface.blit(self.background, (0, 0))
        display_surface.blit(self.nodeset_label, self.nodeset_label_rect)
        self.draw_axes()
        self.draw_function()
        self.draw_markers()
        self.draw_buttons()

    def draw_axes(self):
        """Blit the axis labels built by set_axis_labels."""
        display_surface = self.display_surface
        for xl, yl in self.axis_labels:
            display_surface.blit(*xl)
            display_surface.blit(*yl)

    def draw_function(self):
        """Draw the curve as anti-aliased segments, one pixel column apart."""
        rect = self.plot_rect
        surface = self.display_surface
        nodeset = self.get_nodeset()
        step = 1
        last_y = None
        for x in range(rect.left, rect.right + step, step):
            ii = x - rect.left
            fx = nodeset.get_y(self.get_function_coordinates(ii)[0])
            y = self.get_plot_coordinates(y=fx)[1]
            if ii > 0:
                aaline(surface, self.curve_color, (x - step, last_y), (x, y))
            last_y = y

    def draw_markers(self):
        """Update every node marker."""
        for marker in self.markers:
            marker.update()

    def draw_buttons(self):
        """Update every button."""
        for button in self.buttons:
            button.update()

    def respond_to_key(self, event):
        """Edit the prompt text; Return invokes the prompt callback.

        The prompt closes only when the callback returns a truthy value.
        Accepted characters: alphanumerics, whitespace, ".", "-" and "_",
        up to the prompt's character limit.
        """
        if self.prompt.active:
            prompt = self.prompt
            if event.key == K_RETURN:
                if prompt.callback[0](prompt.text, *prompt.callback[1]):
                    prompt.deactivate()
            elif event.key == K_BACKSPACE:
                prompt.text = prompt.text[:-1]
                prompt.update()
                prompt.draw_text()
            elif (event.unicode.isalnum() or event.unicode.isspace() or \
                  event.unicode in (".", "-", "_")) and len(prompt.text) < \
                  prompt.character_limit:
                prompt.text += event.unicode
                prompt.update()
                prompt.draw_text()

    def show_success_indicator(self):
        """Blink the green square; halts itself after two visible blinks."""
        self.draw()
        if self.success_indicator_blink_count > 1:
            self.success_indicator_blink_count = 0
            self.halt(self.show_success_indicator)
            return
        if self.success_indicator_active:
            self.display_surface.blit(self.success_indicator,
                                      self.success_indicator_rect)
            self.success_indicator_blink_count += 1
        self.success_indicator_active = not self.success_indicator_active

    def save_temporary_file(self):
        """Overwrite the backup file with the current configuration."""
        fp = self.temporary_file
        fp.seek(0)
        fp.truncate()
        self.get_configuration().write(fp)
        # Without a flush the backup can sit in the buffer indefinitely and
        # be lost on a crash, which defeats its purpose.
        fp.flush()

    def rearrange(self):
        """Rebuild every size-dependent surface and rect for the display."""
        self.set_background()
        self.set_success_indicator()
        self.set_plot_rect()
        self.set_markers()
        self.set_nodeset_label()
        self.set_axis_labels()
        self.set_buttons()
        self.prompt.reset()

class Marker(Sprite):
    """Sprite shown over one interior node of the nodeset being edited."""

    def __init__(self, parent, node):
        Sprite.__init__(self, parent)
        # All markers share the single image the GUI built in set_marker_frame.
        shared_frame = parent.marker_frame
        self.add_frame(shared_frame)
        # The GUI removes this node from the nodeset on a right click.
        self.node = node


class Button(Sprite):
    """Text button anchored to the bottom edge of the display at x = ``left``."""

    def __init__(self, parent, text, left):
        Sprite.__init__(self, parent)
        self.set_frame(text)
        display_bottom = self.get_display_surface().get_rect().bottom
        self.location.bottomleft = left, display_bottom

    def set_frame(self, text):
        """Replace the button image with ``text`` rendered black on white."""
        self.clear_frames()
        caption = self.parent.font.render(text, True, (0, 0, 0),
                                          (255, 255, 255))
        self.add_frame(caption)


class Prompt(Sprite):
    """One-line text entry box centered on the display.

    activate() stores a callback and its extra arguments; the GUI calls the
    callback with the entered text when Return is pressed.
    """

    def __init__(self, parent):
        Sprite.__init__(self, parent)
        self.load_configuration()
        self.font = Font(None, self.text_size)
        self.reset()
        self.deactivate()

    def deactivate(self):
        """Mark the prompt as closed."""
        self.active = False

    def load_configuration(self):
        """Read the prompt-* settings from the "interpolator-gui" section."""
        settings = self.get_configuration("interpolator-gui")
        self.size = settings["prompt-size"]
        self.border_color = settings["prompt-border-color"]
        self.border_width = settings["prompt-border-width"]
        self.character_limit = settings["prompt-character-limit"]
        self.text_size = settings["prompt-text-size"]

    def reset(self):
        """Rebuild the frame and re-center it (e.g. after a display change)."""
        self.set_frame()
        self.place()

    def set_frame(self):
        """Build a black box with a ``border_width`` border of ``border_color``."""
        self.clear_frames()
        box = Surface(self.size)
        self.add_frame(box)
        box.fill(self.border_color)
        inset = self.border_width * 2
        inner = box.get_rect().inflate(-inset, -inset)
        box.fill((0, 0, 0), inner)

    def place(self):
        """Center the prompt on the display."""
        self.location.center = self.display_surface.get_rect().center

    def activate(self, text, callback, *args):
        """Open the prompt pre-filled with ``text`` and draw it."""
        self.active = True
        self.text = str(text)
        self.callback = (callback, args)
        self.update()
        self.draw_text()

    def draw_text(self):
        """Blit the current text, white on black, centered in the box."""
        rendered = self.font.render(self.text, True, (255, 255, 255),
                                    (0, 0, 0))
        target = rendered.get_rect()
        target.center = self.location.center
        self.display_surface.blit(rendered, target)
18.222.20.30
18.222.20.30
18.222.20.30
 
June 7, 2018