# Copyright (C) 2017 The Meme Factory, Inc.  http://www.meme.com/
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Karl O. Pinc <kop@meme.com>

'''
Transformation of the tuples which are returned by the parser
into "scan items", which are used to make decision trees.
'''

import re


# Constants
VERSION_RE_STR = r' v[1-9][0-9]*'   # space, "v", then a number with
                                    # no leading zero
PAREN_RE_STR = (r' \((?! )'         # start with space open-paren
                                    # not followed by a space
                r'[a-zA-Z0-9;\ ]+'  # inside parens can be semicolon,
                                    # upper or lower case letters,
                                    # space, or digits
                r'(?<! )\)')        # parentheticals end with close paren
                                    # not preceded by a space; the
                                    # lookbehind must sit before the
                                    # paren so it tests the character in
                                    # front of it (placed after the paren
                                    # it could only ever see ")")
PAREN_RE = re.compile(PAREN_RE_STR)
REPEATED_PAREN_RE_STR = (r'(?:' +        # one or more parentheticals
                         PAREN_RE_STR +
                         r')+')          # (one or more parentheticals)


#
# Classes
#
class ScanItem(object):
    '''Base class for the items a parser scan is converted into.

    Items compare equal, and hash, by their ``type`` alone; ``pos`` is
    not part of an item's identity.  Subclasses implement ``matches``.
    '''
    def __init__(self, type, pos):
        super(ScanItem, self).__init__()
        self.type = type  # kind of item, e.g. 'string' or 'hash'
        self.pos = pos    # position of the item within the parsed pattern

    def __eq__(self, other):
        return self.type == other.type

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this method
        # != would fall back to an identity comparison there.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal

    def __hash__(self):
        return hash(self.type)

    def matches(self, other):
        '''Returns a pair (match_len, report_items)

        match_len     The length of the data which matched
        report_items  "Stuff" to be reported about the match

        Must be overridden by subclasses.
        '''
        raise NotImplementedError()


class ValuedScanItem(ScanItem):
    '''A scan item whose identity is its type together with a value.

    Two valued items are equal only when both their types and their
    values are equal; the hash is derived from the same pair.
    '''
    def __init__(self, type, pos, value):
        super(ValuedScanItem, self).__init__(type, pos)
        self.value = value

    def __eq__(self, other):
        same_type = super(ValuedScanItem, self).__eq__(other)
        if not same_type:
            return same_type
        return self.value == other.value

    def __hash__(self):
        identity = (self.type, self.value)
        return hash(identity)


class StringScanItem(ValuedScanItem):
    '''A literal piece of pattern text which must appear verbatim.'''
    def __init__(self, pos, value):
        super(StringScanItem, self).__init__('string', pos, value)
        self.value_len = len(value)

    def matches(self, other):
        '''Match when ``other`` begins with this item's literal value.

        Returns (len(value), []) on a match and (0, []) otherwise.
        '''
        # startswith() copes with an empty ``other`` and also with values
        # longer than one character, which a comparison against other[0]
        # could never match.
        if other.startswith(self.value):
            return (self.value_len, [])
        return (0, [])


class UserStrScanItem(ScanItem):
    '''Scan item matching free-form user supplied text.

    The text must begin and end with a "terminal" character.  Between
    those it may contain terminals, dashes, and spaces -- but not a space
    which is followed by a dash and another space, so a match never
    extends across a " - " separator.
    '''
    # Characters which may begin or end the text (duplicate "," removed).
    TERMINAL_CHARS = r"[A-Za-z0-9,\#\.\'&]"
    PATTERN_STR = (TERMINAL_CHARS +           # one terminal
                   r'(?:' +                   # then 0 or 1 of (
                   r'(?:' +                   # 0 or many of (
                   TERMINAL_CHARS +           # a terminal
                   r'|\-' +                   # or a dash
                   r'|(?: (?!\- ))' +         # or a space not followed by
                                              # a dash and a space
                   r')*' +                    # )
                   TERMINAL_CHARS +           # followed by one terminal
                   r')?')                     # )
    PATTERN = re.compile(PATTERN_STR)

    def __init__(self, pos, type='<user_str>'):
        # ``type`` is the parser token, e.g. '<user_str>'; its first and
        # last characters (the angle brackets) are dropped.
        super(UserStrScanItem, self).__init__(type[1:-1], pos)

    def matches(self, other):
        '''Return (match_len, []) for the user text at the start of other.'''
        return match_or_empty(self.PATTERN, other)


class SpecialUserStrScanItem(UserStrScanItem):
    '''A user string which is directly followed by literal text.

    ``next_str`` is that literal text.  Before matching, the input is cut
    at the first occurrence of ``next_str`` so that the match cannot
    extend into it.
    '''
    def __init__(self, pos, type, next_str):
        super(SpecialUserStrScanItem, self).__init__(pos, type)
        self.next_str = next_str

    def matches(self, other):
        '''Match the user string against the part of ``other`` before the
        first occurrence of ``next_str`` (all of ``other`` when
        ``next_str`` does not occur).'''
        # A real method (rather than a closure stored on the instance)
        # keeps instances picklable and always uses the current
        # self.next_str, the same value __eq__ and __hash__ use.
        next_str_pos = other.find(self.next_str)
        if next_str_pos < 0:
            next_str_pos = len(other)
        return super(SpecialUserStrScanItem, self).matches(
            other[:next_str_pos])

    def __eq__(self, other):
        return (super(SpecialUserStrScanItem, self).__eq__(other) and
                isinstance(other, SpecialUserStrScanItem) and
                self.next_str == other.next_str)

    def __hash__(self):
        return hash((self.type, SpecialUserStrScanItem, self.next_str))


class HashScanItem(ScanItem):
    '''Scan item matching a run of one or more decimal digits.'''
    PATTERN = re.compile(r'[0-9]+')
    PATTERN_STR = PATTERN.pattern

    def __init__(self, pos):
        super(HashScanItem, self).__init__('hash', pos)

    def matches(self, other):
        '''Return (length of the leading digit run, []), or (0, []).'''
        return match_or_empty(self.PATTERN, other)


class EntityScanItem(UserStrScanItem):
    '''A user string whose item type is "entity"; matched like any other.'''
    def __init__(self, pos):
        super(EntityScanItem, self).__init__(pos, type='<entity>')


class DateScanItem(ScanItem):
    '''Scan item matching digits in the form NNNN-NN-NN.

    Only the shape is checked; no range checking of the digits is done.
    '''
    PATTERN = re.compile(r'[0-9]{4}-[0-9]{2}-[0-9]{2}')
    PATTERN_STR = PATTERN.pattern

    def __init__(self, pos):
        super(DateScanItem, self).__init__('date', pos)

    def matches(self, other):
        '''Return (10, []) when other starts with a date, else (0, []).'''
        return match_or_empty(self.PATTERN, other)


class YearScanItem(ScanItem):
    '''Scan item matching exactly four digits, a year.'''
    PATTERN = re.compile(r'[0-9]{4}')
    PATTERN_STR = PATTERN.pattern

    def __init__(self, pos):
        super(YearScanItem, self).__init__('year', pos)

    def matches(self, other):
        '''Return (4, []) when other starts with four digits, else (0, []).'''
        return match_or_empty(self.PATTERN, other)


class Last4DigitsScanItem(ScanItem):
    '''Scan item matching exactly four digits (item type last_4_digits).'''
    PATTERN = re.compile(r'[0-9]{4}')
    PATTERN_STR = PATTERN.pattern

    def __init__(self, pos):
        super(Last4DigitsScanItem, self).__init__('last_4_digits', pos)

    def matches(self, other):
        '''Return (4, []) when other starts with four digits, else (0, []).'''
        return match_or_empty(self.PATTERN, other)


class VersionScanItem(ScanItem):
    '''Scan item matching a version marker such as " v2" or " v10".'''
    PATTERN = re.compile(VERSION_RE_STR)

    def __init__(self, pos):
        super(VersionScanItem, self).__init__('version', pos)

    def matches(self, other):
        '''Return (match_len, []) for a leading version marker.'''
        return match_or_empty(self.PATTERN, other)


class ParenScanItem(ValuedScanItem):
    '''A specific parenthetical from the pattern.

    The input matches when it begins with the same parenthetical,
    optionally extended with "; <more text>" just before the close paren.
    '''
    def __init__(self, pos, value):
        super(ParenScanItem, self).__init__('paren', pos, value)
        # value presumably looks like " (content)" (cf. PAREN_RE_STR):
        # drop the leading " (" and trailing ")" and escape the content so
        # it is matched literally even if it holds regex metacharacters.
        pattern_str = ''.join([r'\ \(',
                               re.escape(self.value[2:-1]),
                               r'(; [a-zA-Z0-9,&\#\.\ ]+)?',
                               r'\)'])
        self.pattern = re.compile(pattern_str)

    def matches(self, other):
        '''Return (match_len, report_items).

        When the optional "; ..." extension is present, report_items
        holds the whole text between the parentheses; otherwise it is
        empty.  (0, []) means no match.
        '''
        match = self.pattern.match(other)
        if match is None:
            return (0, [])
        start, end = match.span()
        if match.group(1):
            # Skip the leading " (" and the trailing ")".
            return (end - start, [other[start + 2:end - 1]])
        return (end - start, [])


class RepeatedParensScanItem(ScanItem):
    '''Scan item matching a run of one or more back-to-back parentheticals.'''
    PATTERN = re.compile(REPEATED_PAREN_RE_STR)

    def __init__(self, pos):
        super(RepeatedParensScanItem, self).__init__('repeated_paren', pos)

    def matches(self, other):
        '''Return (match_len, contents), where contents lists the text
        inside each matched parenthetical; (0, []) when nothing matches.'''
        found = self.PATTERN.match(other)
        if found is not None:
            matched_text = found.group()
            return (len(matched_text), repeated_paren_content(matched_text))
        return (0, [])


class OptionalScanItem(ValuedScanItem):
    '''An optional ("[...]") section of the pattern.

    ``options`` is the list of scan items for the bracketed contents, as
    built by tuple_to_items.
    '''
    def __init__(self, pos, options):
        super(OptionalScanItem, self).__init__('optional', pos, options)

    # The value is a list, so hashing it would fail anyway.  Make
    # instances explicitly unhashable so they are never used as set
    # members or dict keys; equality still compares type and options.
    __hash__ = None


#
# Functions
#

def match_or_empty(pattern, val):
    '''Match the compiled ``pattern`` at the start of ``val``.

    Returns (length_of_match, []); the length is 0 -- the "no match"
    tuple -- when the pattern does not match.
    '''
    found = pattern.match(val)
    match_len = 0 if found is None else len(found.group())
    return (match_len, [])


def repeated_paren_content(matched):
    '''Split a run of parentheticals into their contents.

    ``matched`` must consist solely of back-to-back parentheticals, as
    matched by REPEATED_PAREN_RE_STR, e.g. " (a) (b c)".  The result is
    the text inside each one, in order, e.g. ['a', 'b c'].
    '''
    contents = []
    remaining = matched
    while remaining:
        paren_end = PAREN_RE.match(remaining).end()
        # Skip the leading " (" and the trailing ")".
        contents.append(remaining[2:paren_end - 1])
        remaining = remaining[paren_end:]
    return contents


def tupllen(tupl):
    '''Return the pattern length covered by one parser tuple.

    An 'optional' tuple counts its contents plus the surrounding "[" and
    "]"; a 'repeated_parens' tuple counts as zero; any other tuple counts
    the length of its token text.
    '''
    kind = tupl[0]
    if kind == 'repeated_parens':
        return 0
    if kind == 'optional':
        # +2 for open and close braces not part of parser token
        return scanlen(tupl[1]) + 2
    return len(tupl[1])


def scanlen(scan):
    '''Return the total pattern length covered by the tuples of a scan.'''
    return sum(tupllen(tupl) for tupl in scan)


def scan_to_specialuserstrscanitem(scan, pos):
    '''Build a SpecialUserStrScanItem for the user string at scan[0].

    The text the item must stop before is the concatenation of the run
    of 'string' tuples immediately following scan[0].
    '''
    following = []
    idx = 1
    while idx < len(scan) and scan[idx][0] == 'string':
        following.append(scan[idx][1])
        idx += 1
    return SpecialUserStrScanItem(pos, scan[0][1], ''.join(following))


def to_scan_items(scan, pos=0):
    '''Flatten a parse scan into a list of scan items.

    ``pos`` is the position of the first tuple; each tuple advances it by
    its tupllen().  Every tuple is converted with the rest of the scan
    still visible, so a conversion can look at the tuples after it.
    '''
    items = []
    for idx in range(len(scan)):
        remainder = scan[idx:]
        items.extend(tuple_to_items(remainder, pos))
        pos += tupllen(remainder[0])
    return items


def tuple_to_items(scan, pos):
    '''Convert the first tuple of a scan into a list of scan items.

    scan  The parser tuples from the one being converted onward; the
          later tuples are consulted only when a 'user_str' is directly
          followed by a 'string'.
    pos   The position of the tuple being converted.

    Raises ValueError when the tuple's type is not recognized.
    '''
    tupl = scan[0]
    typ = tupl[0]

    if typ == 'string':
        # One item per character, each at its own position.
        return [StringScanItem(pos + offset, ch)
                for offset, ch in enumerate(tupl[1])]
    if typ == 'user_str':
        followed_by_string = len(scan) > 1 and scan[1][0] == 'string'
        if followed_by_string:
            return [scan_to_specialuserstrscanitem(scan, pos)]
        return [UserStrScanItem(pos, tupl[1])]
    if typ == 'hash':
        return [HashScanItem(pos)]
    if typ == 'entity':
        return [EntityScanItem(pos)]
    if typ == 'date':
        return [DateScanItem(pos)]
    if typ == 'year':
        return [YearScanItem(pos)]
    if typ == 'last_4_digits':
        return [Last4DigitsScanItem(pos)]
    if typ == 'version':
        return [VersionScanItem(pos)]
    if typ == 'paren':
        return [ParenScanItem(pos, tupl[1])]
    if typ == 'repeated_parens':
        return [RepeatedParensScanItem(pos)]
    if typ == 'optional':
        # The contents start one character in, after the open brace "["
        # which is not part of the parser token.
        return [OptionalScanItem(pos, to_scan_items(tupl[1], pos + 1))]
    raise ValueError("Unknown type of tuple '{0}'".format(tupl))
