# Copyright (C) 2015, 2018, 2020, 2021 The Meme Factory, Inc.
# http://www.karlpinc.com/

# This file is part of PGWUI_Bulk_Upload.
#
# This program is free software: you can redistribute it and/or
# modify it under the terms of the GNU Affero General Public License
# as published by the Free Software Foundation, either version 3 of
# the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public
# License along with this program.  If not, see
# <http://www.gnu.org/licenses/>.
#
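'''PGWUI_Bulk_Upload view:  upload a zip archive of directories of
delimited text files into database tables, as directed by a YAML
map file within each directory.
'''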
from __future__ import generator_stop

from pyramid.view import view_config
import attr
import logging
import os
import os.path
import pathlib
import psycopg2
import tempfile
import zipfile

import yaml

from pgwui_common.view import auth_base_view
import pgwui_core.exceptions as core_ex

from pgwui_core.core import (
    UploadNullFileWTForm,
    UploadEngine,
    ParameterExecutor,
    DataLineProcessor,
    UploadNullMixin,
    UploadDoubleFileForm,
    UploadData,
    escape_eol,
    is_checked,
)
from pgwui_upload_core.views.upload import (
    BaseTableUploadHandler,
    UploadCoreInitialPost,
    set_upload_response,
)

import pgwui_bulk_upload.exceptions as ex


log = logging.getLogger(__name__)


def map_description(filepath, relation):
    return f'Error uploading ({filepath}) into ({relation})'


def archive_path(path):
    '''Return, as text, the path of a file within the archive
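
    e.g. PurePath('/tmp/unzipped/dir1/data.csv') -> 'dir1/data.csv'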
    '''
    return os.path.join(*[os.fsdecode(part) for part in path.parts[-2:]])


class BulkUploadForm(UploadNullMixin, UploadDoubleFileForm):
    '''
    Acts like a dict, but with extra methods.

    Attributes:
      uh      The UploadHandler instance using the form

    Methods:
      read()  Load form from pyramid request object.
    '''
    def read(self):
        '''
        Read form data from the client
        '''
        # Read parent's data
        super().read()

    def write(self, result, errors):
        '''
        Produces the dict pyramid will use to render the form.
        '''
        response = super().write(result, errors)
        return super().write_response(response)


class SaveBulkLine(DataLineProcessor, ParameterExecutor):
    def __init__(self, ue, uh, insert_map):
        '''
        ue             UploadEngine instance
        uh             UploadHandler instance
        insert_map     Dict mapping file to insert statement used to insert
                       into db. (psycopg2 formatted for substitution)
        '''
        super().__init__(ue, uh)
        self.insert_map = insert_map

    def eat(self, bulk_data):
        '''
        Upload a line of data into the db.

        bulk_data  A tuple:
                   (UploadBulkData instance,
                    thunk returning UploadDataLine instance)
        '''
        (data, thunk) = bulk_data
        filepath = data.filepath
        try:
            udl = thunk()
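            # Substitute this line's values into the INSERT statement
            # prepared for its file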
            self.param_execute(self.insert_map[filepath], udl)
        except (core_ex.DataLineError, core_ex.DBError) as exp:
            relation = data.relation
            raise exp.color(map_description(filepath, relation),
                            filepath, relation)
        except core_ex.MultiError as exp:
            relation = data.relation
            for error in exp.errors:
                if isinstance(error, core_ex.UploadError):
                    error.color(map_description(filepath, relation),
                                filepath, relation)
            raise exp
        except psycopg2.DatabaseError as exp:
            relation = data.relation
            raise core_ex.DBDataLineError(udl, exp).color(
                map_description(filepath, relation),
                filepath, relation)


class UploadBulkData(UploadData):
    def __init__(self, fileo, file_fmt, null_data, null_rep,
                 path, relation, trim=True):
        '''
        fileo       Stream to uploaded file
        file_fmt    File format: CSV or TAB
        null_data   (boolean) Uploaded data contains nulls
        null_rep    Uploaded string representation of null
        path        pathlib path to file.
        relation    Relation, possibly schema qualified
        trim        (boolean) Trim leading and trailing whitespace?

        filepath    Path of file (zip_root relative)
        '''
        super().__init__(fileo, file_fmt, null_data, null_rep, trim=trim)
        self.path = path
        self.filepath = archive_path(path)
        self.relation = relation

    def _thunk(self):
        '''Get the thunk which returns the next udl
        '''
        # Reopen the file, now that it is time to upload it
        try:
            self.open_fileo(self.path.open('rb'))
        except OSError as exp:
            saved = exp

            def thunk():
                raise ex.CannotReReadError(self.filepath, self.lineno, saved)
            yield thunk
            self.lineno -= 1
            return   # shut down generator
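        # Both super()._thunk() calls make fresh generators over the
        # same, just reopened, file object; so the first next() consumes
        # only the header line and the second resumes at the data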
        next(super()._thunk())             # skip header
        yield from super()._thunk()


@attr.s
class UploadDir():
    '''Uploaded archive directory

    Iterating yields thunks, each returning an (UploadBulkData,
    data-line thunk) pair for SaveBulkLine.eat().
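
    The map file is expected to hold YAML of roughly this shape (a
    sketch inferred from the validators below; names are examples):

        map_list:
          - file_map:
              file: animals.csv
              relation: myschema.animals
              trim: true        # optional, boolean
          - file_map:
              file: people.csv
              relation: people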
    '''
    uf = attr.ib()
    dentry = attr.ib()
    map_file = attr.ib()
    filedata = attr.ib(factory=list)

    def get_file_map(self, yaml_file):
        try:
            with yaml_file.open() as yaml_fh:
                return yaml.safe_load(yaml_fh)
        except yaml.YAMLError as exp:
            raise ex.BadMapfileError(
                self.uf['filename'], archive_path(yaml_file), exp)

    def check_tag(self, errors, yaml_file, count, file_map, tag,
                  required=True, string=True):
        '''Confirm that the tag exists and holds a string;
        remove the tag from the map
        '''
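        # file_map[tag] is looked up inside the try so that a missing
        # tag surfaces as KeyError; the else branch runs, and removes
        # the tag, only when the tag is present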
        try:
            if not isinstance(file_map[tag], str) and string:
                errors.append(ex.MustBeStringError(
                    archive_path(yaml_file), count, tag, file_map[tag]))
        except KeyError:
            if required:
                errors.append(ex.MissingFileMapTagError(
                    archive_path(yaml_file), count, tag))
        else:
            del file_map[tag]

    def extra_tags(self, errors, yaml_file, count, map_item, execp):
        '''Confirm that there are no unrecognized tags

        execp   Exception class instantiated for each extra tag found
        '''
        for key in map_item:
            errors.append(execp(archive_path(yaml_file), count, key))

    def validate_key_existence(self, errors, yaml_file, count, file_map):
        '''Confirm that a file_map contains exactly the expected keys,
        if not add to errors
        '''
        my_file_map = file_map.copy()
        self.check_tag(errors, yaml_file, count, my_file_map, 'file')
        self.check_tag(errors, yaml_file, count, my_file_map, 'relation')
        self.check_tag(
            errors, yaml_file, count, my_file_map, 'trim',
            required=False, string=False)
        self.extra_tags(errors, yaml_file, count, my_file_map,
                        ex.ExtraFileMapTagError)

    def validate_values(self, errors, yaml_file, count, file_map):
        '''Confirm that a file_map contains the right tag values, if not
        add to errors
        (Only used for values that are not a string.)
        '''
        if 'trim' in file_map:
            value = file_map['trim']
            if not isinstance(value, bool):
                errors.append(ex.BadTrimValueError(
                    archive_path(yaml_file), count, value))

    def validate_file_map(self, errors, yaml_file, count, file_map):
        '''Confirm that a file_map contains the right tags, if not
        add to errors
        '''
        self.validate_key_existence(errors, yaml_file, count, file_map)
        self.validate_values(errors, yaml_file, count, file_map)

    def validate_map_item(self, errors, yaml_file, count, map_item):
        '''Confirm that a map_item is a dict with the right tags, if not
        add to errors
        '''
        if not isinstance(map_item, dict):
            errors.append(ex.BadMapListEntryError(
                archive_path(yaml_file), map_item))
            return
        my_map_item = map_item.copy()
        my_map_item.pop('file_map', None)
        self.extra_tags(errors, yaml_file, count, my_map_item,
                        ex.ExtraMapListTagError)
        if 'file_map' not in map_item:
            errors.append(ex.MissingMapListTagError(
                archive_path(yaml_file), count, 'file_map'))
            return
        self.validate_file_map(errors, yaml_file, count, map_item['file_map'])

    def detect_duplicates(self, errors, yaml_file, map_list):
        '''Add to errors if the same file is uploaded into the same table
        more than once
        '''
        pairs = dict()
        for count, map_item in enumerate(map_list, start=1):
            file_map = map_item['file_map']
            file = file_map['file']
            relation = file_map['relation']
            pair = (file, relation)
            if pair in pairs:
                errors.append(ex.DuplicateMappingError(
                    archive_path(yaml_file), pairs[pair],
                    count, file, relation))
            else:
                pairs[pair] = count

    def validate_map_list(self, yaml_file, map_list):
        '''Confirm that the map list contains the right YAML structure
        '''
        errors = []
        for count, map_item in enumerate(map_list, start=1):
            self.validate_map_item(errors, yaml_file, count, map_item)
        if not errors:
            # detect_duplicates assumes well-formed map items, so skip
            # it when validation has already failed
            self.detect_duplicates(errors, yaml_file, map_list)
        if errors:
            raise core_ex.MultiError(errors)

    def validate_mapfile(self, yaml_file, file_map):
        '''Confirm that the file map has the right YAML structure
        '''
        if not isinstance(file_map, dict) or 'map_list' not in file_map:
            raise ex.NoMapListError(archive_path(yaml_file))

        map_list = file_map['map_list']
        if not isinstance(map_list, list):
            raise ex.BadMapListError(archive_path(yaml_file))

        self.validate_map_list(yaml_file, map_list)

    def collect_files(self, dir_name):
        '''Return a list of all the file names, except the map file, in
        dir_name
        '''
        map_file = self.map_file
        file_names = []
        with os.scandir(dir_name) as direntry:
            for file in direntry:
                name = os.fsdecode(file.name)
                if name != map_file:
                    file_names.append(name)
        return file_names

    def match_map_to_fs(self, dir_name, yaml_file, file_map):
        '''Produce errors if the files in the dir do not match the names
        in the file map
        '''
        errors = []
        file_names = set(self.collect_files(dir_name))
        item_files = {item['file_map']['file']
                      for item in file_map['map_list']}

        missing_files = item_files - file_names
        for item in missing_files:
            errors.append(ex.MissingFileError(
                archive_path(yaml_file), archive_path(dir_name / item)))

        extra_files = file_names - item_files
        for extrafile in extra_files:
            errors.append(ex.ExtraFileError(
                archive_path(yaml_file), archive_path(dir_name / extrafile)))

        if errors:
            raise core_ex.MultiError(errors)

    def _trim(self, fmap):
        '''Should the file be trimmed?

        A per-file 'trim' setting overrides the form-wide default.
        '''
        return fmap.get('trim', self.uf['trim_upload'])

    def get_filedata(self, dir_name, file_map):
        '''Return a list of UploadData instances or raise an error

        The list is in the order given in the map file.
        '''
        uf = self.uf
        filedata = []
        errors = []
        for item in file_map['map_list']:
            fmap = item['file_map']
            name = dir_name / fmap['file']
            try:
                fh = name.open('rb')
            except OSError as exp:
                errors.append(ex.CannotReadError(archive_path(name), exp))
            else:
                try:
                    filedata.append(UploadBulkData(fh,
                                                   uf['upload_fmt'],
                                                   uf['upload_null'],
                                                   uf['null_rep'],
                                                   name,
                                                   fmap['relation'],
                                                   trim=self._trim(fmap)))
                except core_ex.PGWUIError as exp:
                    relation = fmap['relation']
                    errors.append(exp.color(map_description(name, relation),
                                            name, relation))
        if errors:
            raise core_ex.MultiError(errors)
        return filedata

    def open(self, dir_name):
        '''
        Initialize object, return number of files to upload
        dir_name    Path of directory to upload, a pathlib.Path
        '''
        yaml_file = dir_name / self.map_file
        file_map = self.get_file_map(yaml_file)
        self.validate_mapfile(yaml_file, file_map)
        self.match_map_to_fs(dir_name, yaml_file, file_map)
        self.filedata = self.get_filedata(dir_name, file_map)
        return len(self.filedata)

    def __iter__(self):
        for data in self.filedata:
            for thunk in data:
                # Bind through default arguments so each lambda captures
                # the current pair, not the loop variables' final values
                yield lambda data=data, thunk=thunk: (data, thunk)

    def fileinfo(self):
        yield from self.filedata

    def stats(self):
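        '''Yield a (filepath, lines read, relation) tuple for each file
        '''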
        for fileinfo in self.filedata:
            yield (fileinfo.filepath, fileinfo.lineno, fileinfo.relation)


@attr.s
class UploadArchive():
    '''Uploaded archive

    Iterating yields the per-line thunks of each UploadDir, in sorted
    directory order; lineno counts every line consumed, headers included.
    '''
    lineno = attr.ib(default=0)
    updirs = attr.ib(factory=list)

    def open(self, uf, zip_root, map_file):
        '''
        uf          The upload form
        zip_root    Root of unarchived upload, a pathlib.Path
        map_file    Name of the file which maps tables to files
        '''
        err_tups = []
        with os.scandir(zip_root) as archive:
            for dentry in archive:
                updir = UploadDir(uf, dentry, map_file)
                try:
                    # Every file consumes a header line
                    self.lineno += updir.open(zip_root / dentry.name)
                except ex.SetupError as exp:
                    err_tups.append((dentry.name, exp))
                except core_ex.MultiError as exp:
                    err_tups.extend([(dentry.name, err) for err in exp.errors])
                else:
                    self.updirs.append(updir)
        if err_tups:
            err_tups.sort(key=lambda tup: tup[0])
            raise core_ex.MultiError([tup[1] for tup in err_tups])

        self.updirs.sort(key=lambda elmt: os.fsdecode(elmt.dentry.name))

    def __iter__(self):
        for updir in self.updirs:
            for thunk in updir:
                self.lineno += 1
                yield thunk

    def filedata(self):
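        '''Yield the UploadBulkData instances of every directory
        '''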
        for updir in self.updirs:
            for info in updir.fileinfo():
                yield info

    def stats(self):
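        '''Yield the per-file (filepath, lineno, relation) stats of
        every directory
        '''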
        for updir in self.updirs:
            for stat in updir.stats():
                yield stat


class BulkTableUploadHandler(BaseTableUploadHandler):
    '''
    Attributes:
      request       A pyramid request instance
      uf            A BulkUploadForm instance
      session       A pyramid session instance
      ue            The UploadEngine instance driving the upload
      cur           The db cursor supplied by the engine
      map_file      Name of the per-directory map file, from the settings
      temp_dir      TemporaryDirectory holding the unpacked archive
    '''
    def __init__(self, request):
        '''
        request A pyramid request instance
        '''
        super().__init__(request)
        self.temp_dir = None

    def cleanup(self):
        if self.temp_dir:
            self.temp_dir.cleanup()
        super().cleanup()

    def make_form(self):
        '''
        Make the upload form needed by this handler.
        '''
        return BulkUploadForm().build(
            self, fc=UploadNullFileWTForm,
            ip=UploadCoreInitialPost().set_component('pgwui_bulk_upload'))

    def get_data(self):
        '''
        Set self.data to an empty UploadArchive; factory() opens it later.
        '''
        self.data = UploadArchive()

    def validate_mapfile(self, updir, errors):
        '''Add an error to errors if the map file does not exist
        '''
        map_file_path = pathlib.Path(updir) / self.map_file
        if not map_file_path.exists():
            errors.append(ex.NoMapfileError(
                self.uf['filename'], archive_path(map_file_path)))

    def validate_updir(self, updir):
        '''Return the list of errors found in one directory of the
        uploaded archive
        '''
        errors = []
        with os.scandir(updir) as updir_entry:
            for entry in updir_entry:
                if not entry.is_file():
                    errors.append(ex.NotAFileError(
                        self.uf['filename'],
                        archive_path(pathlib.Path(entry))))
        self.validate_mapfile(updir, errors)
        return errors

    def validate_archive_structure(self):
        '''Return a list of errors found in the structure of the archive's
        contents, sorted by top-level directory name
        '''
        self.map_file = (self.request.registry.settings['pgwui']
                         ['pgwui_bulk_upload']['map_file'])
        err_tups = []
        count = 0
        with os.scandir(self.zip_root) as archive:
            for entry in archive:
                if entry.is_dir():
                    err_tups.extend([
                        (entry.name, error) for error in
                        self.validate_updir(entry)])
                else:
                    err_tups.append(
                        (entry.name,
                         ex.NotADirectoryError(
                             self.uf['filename'], os.fsdecode(entry.name))))
                count += 1
        if count == 0:
            return [ex.EmptyArchiveError(self.uf['filename'])]
        err_tups.sort(key=lambda tup: tup[0])
        return [tup[1] for tup in err_tups]

    def unarchive(self):
        '''Unarchive the zip file.  Return a list of errors.
        '''
        fh = self.uf['localfh']
        try:
            with zipfile.ZipFile(fh) as zfile:
                zfile.extractall(self.zip_root)
        except zipfile.BadZipFile as exp:
            return [ex.CannotUnarchiveError(self.uf['filename'], exp)]
        return self.validate_archive_structure()

    def validate_archive(self):
        '''Validate the uploaded archive; un-archive it and save
        a path to the decompressed archive.  Return a list of errors.
        '''
        if self.uf['filename'] == '':
            return []   # Don't bother, there is no uploaded file
        fh = self.uf['localfh']
        if not zipfile.is_zipfile(fh):
            return [ex.NotAZipfileError(self.uf['filename'])]
        self.temp_dir = tempfile.TemporaryDirectory()
        self.zip_root = pathlib.Path(self.temp_dir.name)
        return self.unarchive()

    def val_input(self):
        '''
        Validate input needed beyond that required to connect to the db.

        Returns:
          A list of PGWUIError instances
        '''
        errors = super().val_input()

        errors.extend(self.validate_archive())

        return errors

    def quote_columns(self):
        return super().quote_columns(
            self.request.registry.settings['pgwui']['pgwui_bulk_upload'])

    def insert_map(self):
        '''Return a map (dict) of file paths to insert statements
        '''
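        # Note: build_insert_stmt queries the db.  After a db error we
        # must roll back before issuing further queries; in_trans tracks
        # whether we are still inside the engine's open transaction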
        quotecols = self.quote_columns()
        column_quoter = self.get_column_quoter(quotecols)

        in_trans = True
        i_map = dict()
        errors = []
        for fileinfo in self.data.filedata():
            try:
                i_map[fileinfo.filepath] = self.build_insert_stmt(
                    fileinfo, fileinfo.relation, quotecols, column_quoter)
            except core_ex.PGWUIError as exp:
                fileinfo.lineno = 0   # Don't report number of lines in file
                exp.color(
                    map_description(fileinfo.filepath, fileinfo.relation),
                    fileinfo.filepath, fileinfo.relation)
                errors.append(exp)
                if in_trans:
                    # In order to continue to query the db after a db error
                    # the current transaction must be rolled back.
                    # Because we know we're going to abort, don't start a
                    # new transaction.
                    try:
                        self.cur.execute('ROLLBACK;')
                    except psycopg2.DatabaseError as exp:
                        err = ex.CannotRollbackError(exp).color(
                            map_description(fileinfo.filepath,
                                            fileinfo.relation),
                            fileinfo.filepath, fileinfo.relation)
                        errors.append(err)
                    in_trans = False
            finally:
                # Limit number of open files, close the file handle until it
                # is time to read the file
                fileinfo.close_fileo()
        if errors:
            raise core_ex.MultiError(errors)

        return i_map

    def factory(self, ue):
        '''Make a db loader function from an UploadEngine.

        Input:
          ue    An UploadEngine instance

        Side Effects:
        Yes, lots.
        '''
        self.ue = ue
        self.cur = ue.cur

        self.data.open(self.uf, self.zip_root, self.map_file)

        return SaveBulkLine(self.ue, self, self.insert_map())

    def stats(self):
        if self.data:
            return self.data.stats()
        return []


def statmap(stats):
    '''Return map of file path to (lines, relation) tuple
    '''
    return {stat[0]: stat[1:] for stat in stats}


def inserted_rows(stats, response):
    '''Return the number of rows inserted
    '''
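    # response['lines'] counts every line read, including the single
    # header line of each file; stats has one entry per file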
    if stats:
        return response['lines'] - len(stats)
    return 0


def relation_stats(stats, quote_columns):
    '''Produce summaries per relation from the stats
    '''
    relations = dict()
    for (path, lines, relation) in stats:
        if quote_columns:
            key = relation.lower()
        else:
            key = relation
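        # lines includes the file's one header line; don't count it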
        count = relations.get(key, 0)
        relations[key] = count + lines - 1
    return sorted(relations.items())


def log_success(response, stats, i_rows, r_stats):
    if is_checked(response['csv_checked']):
        upload_fmt = 'CSV'
    else:
        upload_fmt = 'TAB'

    msg = [('Successful bulk upload: '
            f"Archive file ({response['filename']}): "
            f"DB {response['db']}: "
            f"Inserted a total of {i_rows} rows: "
            f'From {len(stats)} files: ')]
    if len(stats) != len(r_stats):
        # There were multiple files loaded into the same relation
        for (relation, count) in r_stats:
            msg.append(
                f'Inserted {count} rows into relation ({relation}): ')
    for (path, lines, relation) in stats:
        msg.append(
            f'Uploaded the {lines} lines of file ({path}) into ({relation}): ')
    msg.append(f'Format {upload_fmt}: '
               f"Upload Null {is_checked(response['upload_null'])}: "
               f"Null Rep ({escape_eol(response['null_rep'])}): "
               f"Trim {is_checked(response['trim_upload'])}: "
               f"By user {response['user']}")

    log.info(''.join(msg))


def analyze_results(uh, response):
    stats = list(uh.stats())
    response['stats'] = stats
    response['statmap'] = statmap(stats)
    i_rows = inserted_rows(stats, response)
    response['inserted_rows'] = i_rows
    r_stats = relation_stats(stats, uh.quote_columns())
    response['relation_stats'] = r_stats

    if response['db_changed']:
        log_success(response, stats, i_rows, r_stats)

    return response


@view_config(route_name='pgwui_bulk_upload',
             renderer='pgwui_bulk_upload:templates/bulk_upload.mak')
@auth_base_view
def bulk_upload_view(request):

    uh = BulkTableUploadHandler(request).init()
    response = UploadEngine(uh).run()

    set_upload_response('pgwui_bulk_upload', request, response)

    return analyze_results(uh, response)
