# -------------------------------------------------------------------------
#     This file is part of mMass - the spectrum analysis tool for MS.
#     Copyright (C) 2005-07 Martin Strohalm <mmass@biographics.cz>

#     This program is free software; you can redistribute it and/or modify
#     it under the terms of the GNU General Public License as published by
#     the Free Software Foundation; either version 2 of the License, or
#     (at your option) any later version.

#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#     GNU General Public License for more details.

#     Complete text of GNU GPL can be found in the file LICENSE in the
#     main directory of the program
# -------------------------------------------------------------------------

# Function: Count peptide fragments and match to peaklist.

# load modules
from nucleus import commfce

class mFragCount:
    """Count peptide fragments and match to given peaklist"""

    # ----
    def __init__(self, config):
        """ Remember the configuration object and start with empty control data. """

        # run-time parameters ('masstype', 'tolerance', 'errortype',
        # 'mainions', 'internalions') are read from this dictionary by the
        # methods below; nothing in this class writes them, so they have to
        # be supplied from outside before counting or matching
        self.ctrlData = {}

        # configuration object; the methods below read its 'elem' and 'cfg'
        self.config = config
    # ----


    # ----
    def getMainFragments(self, parsedSeq):
        """ Build the table of main fragment ions, one row per residue.

        parsedSeq is a sequence of residue records. Index 0 of a record is
        the residue symbol, index 1 or 2 its mass (index 1 is used when
        ctrlData['masstype'] is 'mmass', index 2 otherwise) and a record
        longer than three items marks the residue as modified.

        Returns a list of dictionaries with these keys:
            'amino'   residue symbol, followed by '*' if modified
            'filter'  string made by filterMainFragments() for this row
            'a0'..'a5', 'b0'..'b5', 'y0'..'y5'
                      the ion, the ion - (N + 3H), the ion - (2H + O) and
                      the (mass + H)/2 form of these three, in this order
            'c0', 'x0', 'z0' and 'c1', 'x1', 'z1'
                      the ion and its (mass + H)/2 form
            'im'      immonium ion of the residue
        Every ion is stored as [mass, '']; the second item is left empty
        here and is written by matchDataToPeaklist().
        """

        # element masses for the selected mass type
        massKey = self.ctrlData['masstype']
        elements = self.config.elem
        C = elements['C'][massKey]
        H = elements['H'][massKey]
        O = elements['O'][massKey]
        N = elements['N'][massKey]

        # index of the mass inside a residue record (1 for 'mmass', else 2)
        massType = 2
        if massKey == 'mmass':
            massType = 1

        # masses of terminal groups
        nTermMass = H
        cTermMass = O + H

        # neutral pieces used repeatedly below
        carbonyl = C + O
        ammonia = N + 3*H
        water = H*2 + O

        # residue masses, plain sequence and mass of the whole chain
        residueMasses = [item[massType] for item in parsedSeq]
        sequence = ''.join([item[0] for item in parsedSeq])
        totalMass = sum(residueMasses)
        lastIndex = len(parsedSeq) - 1

        # running masses of each series before the first residue is added
        seriesA = - (C + H + O) + H
        seriesB = - H + H
        seriesC = + (N + 2*H) + H
        seriesX = totalMass - seriesA
        seriesY = totalMass + 2*H
        seriesZ = totalMass - N - H + H

        def halved(mass):
            """ (mass + H)/2 form of an ion. """
            return (mass + H)/2

        def withLosses(mass):
            """ Six values stored under the suffixes 0..5 of a series. """
            noAmmonia = mass - ammonia
            noWater = mass - water
            return (mass, noAmmonia, noWater,
                    halved(mass), halved(noAmmonia), halved(noWater))

        fragments = []
        for index, item in enumerate(parsedSeq):
            residueMass = residueMasses[index]

            # residue label, modified residues are marked by asterisk
            label = item[0]
            if len(item) > 3:
                label += '*'

            # immonium ion; the terminal group mass is taken off at both
            # ends of the chain
            immonium = residueMass - carbonyl + H
            if index == 0:
                immonium -= nTermMass
            if index == lastIndex:
                immonium -= cTermMass

            # a, b, c series grow by the current residue
            seriesA += residueMass
            seriesB += residueMass
            seriesC += residueMass

            # x, y, z series shrink by the previous residue
            if index > 0:
                previousMass = residueMasses[index-1]
                seriesX -= previousMass
                seriesY -= previousMass
                seriesZ -= previousMass

            # ions which are theoretical only for this cleavage site
            ionFilter = self.filterMainFragments(sequence[:index+1], sequence[index:])

            # assemble the row, every ion gets its own [mass, match] list
            row = {'amino': label, 'filter': ionFilter}
            for series, mass in (('a', seriesA), ('b', seriesB), ('y', seriesY)):
                for suffix, value in enumerate(withLosses(mass)):
                    row[series + str(suffix)] = [value, '']
            singles = (('c', seriesC), ('x', seriesX), ('z', seriesZ))
            for series, mass in singles:
                row[series + '0'] = [mass, '']
            for series, mass in singles:
                row[series + '1'] = [halved(mass), '']
            row['im'] = [immonium, '']

            fragments.append(row)

        return fragments
    # ----


    # ----
    def getInternalFragments(self, parsedSeq):
        """ Build the list of internal fragment ions.

        parsedSeq has the same layout as for getMainFragments(): residue
        symbol at index 0, mass at index 1 ('mmass') or 2 (any other mass
        type) and modifications from index 3 on.

        Every stretch of at least two consecutive residues which contains
        neither the first nor the last residue gives one dictionary:
            'filter'  string made by filterInternalFragments()
            'pos'     '[first-last]' with 1-based residue positions
            'seq'     residue symbols followed by formatted modifications
            'int0'    [mass + H, '']
            'int1'    [int0 - (C + O), '']
            'int2'    [int0 - (N + 3H), '']
            'int3'    [int0 - (2H + O), '']
        """

        # element masses for the selected mass type
        massKey = self.ctrlData['masstype']
        elements = self.config.elem
        C = elements['C'][massKey]
        H = elements['H'][massKey]
        O = elements['O'][massKey]
        N = elements['N'][massKey]

        # index of the mass inside a residue record (1 for 'mmass', else 2)
        massType = 2
        if massKey == 'mmass':
            massType = 1

        # neutral pieces subtracted from the basic ion
        carbonyl = C + O
        ammonia = N + 3*H
        water = 2*H + O

        fragments = []
        lastResidue = len(parsedSeq) - 1

        # first residue of the fragment (N-terminal residue is never used)
        for start in range(1, lastResidue):
            opening = parsedSeq[start]
            chainMass = opening[massType]
            chainSeq = opening[0]
            # slicing makes a copy, so the record itself is not changed below
            chainMods = opening[3:]

            # extend the fragment one residue at a time (C-terminal residue
            # is never used)
            for stop in range(start+1, lastResidue):
                residue = parsedSeq[stop]
                chainMass += residue[massType]
                chainSeq += residue[0]
                chainMods.extend(residue[3:])

                # count ions
                int0 = chainMass + H

                # add to list
                fragments.append({
                    'filter': self.filterInternalFragments(chainSeq),
                    'pos': '[%d-%d]' % (start+1, stop+1),
                    'seq': chainSeq + self.formatModifs(chainMods),
                    'int0': [int0, ''],
                    'int1': [int0 - carbonyl, ''],
                    'int2': [int0 - ammonia, ''],
                    'int3': [int0 - water, ''],
                })

        return fragments
    # ----


    # ----
    def filterMainFragments(self, seqN, seqC):
        """ Name main ions which are theoretical only for one cleavage site.

        seqN and seqC are the residue symbols of the N-terminal and the
        C-terminal part. For each rule below the listed ions are filtered
        when none of the rule's residues is present in the respective part.
        The result is one string of ion names, each followed by ';'; a name
        may appear more than once.
        """

        def lacks(sequence, residues):
            """ True if none of the residues occurs in the sequence. """
            for residue in residues:
                if residue in sequence:
                    return False
            return True

        # (residues, ions filtered by seqN, ions filtered by seqC)
        # NOTE(review): the 2+ rule tests the same residues as the -NH3
        # rule - confirm this is intended
        rules = (
            ('STED', 'a2;b2;a5;b5;', 'y2;y5;'),                    # -H2O
            ('RKQN', 'a1;b1;a4;b4;', 'y1;y4;'),                    # -NH3
            ('RKQN', 'a3;a4;a5;b3;b4;b5;c1;', 'x1;y3;y4;y5;z1;'),  # 2+
        )

        filtered = []
        for residues, ionsN, ionsC in rules:
            if lacks(seqN, residues):
                filtered.append(ionsN)
            if lacks(seqC, residues):
                filtered.append(ionsC)

        return ''.join(filtered)
    # ----


    # ----
    def filterInternalFragments(self, sequence):
        """ Name internal ions which are theoretical only for the sequence.

        'int2' is filtered when the sequence has none of R, K, Q, N and
        'int3' when it has none of S, T, E, D. Returns the names in one
        string, each followed by ';'.
        """

        # (residues of which at least one is needed, ion filtered otherwise)
        rules = (
            ('RKQN', 'int2;'),    # -NH3
            ('STED', 'int3;'),    # -H2O
        )

        ionFilter = ''
        for residues, ionName in rules:
            present = [residue for residue in residues if residue in sequence]
            if not present:
                ionFilter += ionName

        return ionFilter
    # ----


    # ----
    def formatModifs(self, modifs):
        """ Count and format list of modifications.

        modifs is a sequence of modification names, repeated names are
        counted. Returns ' (<count>x <name>; <count>x <name>...)' with the
        names in the order of their first appearance, or an empty string
        when there is nothing to report.
        """

        # nothing to format; the truth test covers every empty input ([],
        # () or None), where the former comparison with [] let the others
        # through and produced ' ()' or failed
        if not modifs:
            return ''

        # count modifs, remember the order in which they were first seen so
        # the output does not depend on dictionary ordering
        modCounter = {}
        modOrder = []
        for mod in modifs:
            if mod in modCounter:
                modCounter[mod] += 1
            else:
                modCounter[mod] = 1
                modOrder.append(mod)

        # format modifs
        labels = [str(modCounter[mod]) + 'x ' + str(mod) for mod in modOrder]
        return ' (' + '; '.join(labels) + ')'
    # ----


    # ----
    def matchDataToPeaklist(self, peakList, fragments, countedIons):
        """ Match peaklist data with the list of generated fragment ions.

        peakList     sequence of peaks, the mass is read from index 0
        fragments    rows made by getMainFragments() or
                     getInternalFragments(); updated in place
        countedIons  names of the ions which should be matched

        For every processed ion the second item of its [mass, match] list is
        rewritten to the indexes of all peaks lying within the tolerance of
        the ion mass, each index followed by ';' (empty string if there is
        none). Ions which are skipped (not counted or discarded) keep
        whatever match they had before.

        Returns (fragments, matched); fragments is the list passed in and
        matched is True if at least one peak was assigned during this call.
        """

        # keys of a fragment row which do not hold an ion
        infoKeys = ('filter', 'amino', 'pos', 'seq')

        # ions never matched in the first row
        firstRowSkipped = ('y0', 'y1', 'y2', 'x0', 'z0', 'y3', 'y4', 'y5', 'x1', 'z1')

        # the only ions matched in the last row
        lastRowAllowed = ('im',) + firstRowSkipped + ('int0', 'int1', 'int2', 'int3')

        matched = False
        tolerance = self.ctrlData['tolerance']
        errorType = self.ctrlData['errortype']
        lastRow = len(fragments) - 1

        for row, fragment in enumerate(fragments):
            for ion in fragment:

                # use counted ions only
                if ion in infoKeys or ion not in countedIons:
                    continue

                # discard some ions
                if row == 0 and ion in firstRowSkipped:
                    continue
                if row == lastRow and ion not in lastRowAllowed:
                    continue

                ionData = fragment[ion]

                # filtered ions are cleared unless they should be matched too
                if ion in fragment['filter'] and not self.config.cfg['mfrag']['matchfiltered']:
                    ionData[1] = ''
                    continue

                # count tolerance from ion mass and error
                ionMass = ionData[0]
                massTolerance = commfce.countTolerance(ionMass, tolerance, errorType)
                lowMass = ionMass - massTolerance
                highMass = ionMass + massTolerance

                # check peaklist; every peak in the window is recorded, the
                # consumer in getMatchInfo() splits the string by ';' (only
                # the last matching peak used to be kept here)
                hits = []
                for index, peak in enumerate(peakList):
                    if lowMass <= peak[0] <= highMass:
                        hits.append(str(index) + ';')
                if hits:
                    matched = True
                ionData[1] = ''.join(hits)

        return fragments, matched
    # ----


    # ----
    def getMatchInfo(self, peaklist, mainFragments, internalFragments, sequence, errorType, digits):
        """ Get match-info for each peak in main peaklist.

        peaklist           sequence of peaks, the mass is read from index 0
        mainFragments      rows made by getMainFragments() and processed by
                           matchDataToPeaklist()
        internalFragments  rows made by getInternalFragments() and processed
                           by matchDataToPeaklist()
        sequence           only its length is reported
        errorType          stored unchanged under data['hidden']['errortype']
        digits             not used by this method

        Returns a dictionary:
            'params'  list of [label, value] pairs of strings
            'errors'  list of [peak mass, peak mass - ion mass], one item
                      for every peak assigned to a counted ion
            'hidden'  {'errortype': errorType}

        An ion matched in a fragment row is used when it is not named in the
        row's 'filter' or when cfg['mfrag']['matchfiltered'] is set. The same
        rule is applied to main and internal fragments (internal ones used
        to be dropped regardless of the setting, so peaks assigned to them
        were reported as missed).
        """

        mainIons = self.ctrlData['mainions']
        internalIons = self.ctrlData['internalions']

        matchedPeaks = set()
        errorList = []

        def isAccepted(ion, ionFilter):
            """ False for a filtered ion unless filtered ions are matched too. """
            return ion not in ionFilter or self.config.cfg['mfrag']['matchfiltered']

        def collectPeaks(ionData):
            """ Register all peaks assigned to one ion, return their number. """
            ionMass = ionData[0]
            hits = 0
            for peakIndex in ionData[1].split(';'):
                if peakIndex == '':
                    continue
                matchedPeaks.add(peakIndex)
                peakMass = peaklist[int(peakIndex)][0]
                errorList.append([peakMass, peakMass-ionMass])
                hits += 1
            return hits

        # series which are not counted at all are reported as such
        matchedIons = {'a0':'', 'b0':'', 'c0':'', 'x0':'', 'y0':'', 'z0':'', 'int0':0}
        for ion in matchedIons:
            if ion not in mainIons and ion not in internalIons:
                matchedIons[ion] = "Not counted"

        # main fragments
        rowCount = len(mainFragments)
        for row, fragment in enumerate(mainFragments):
            for ion in fragment:

                # matched and not filtered ion only
                if ion not in mainIons:
                    continue
                ionData = fragment[ion]
                if ionData[1] == '' or not isAccepted(ion, fragment['filter']):
                    continue

                hits = collectPeaks(ionData)

                # get ion name, once for every assigned peak; a, b, c are
                # numbered from the first row, x, y, z from the last one
                if ion in ('a0', 'b0', 'c0'):
                    matchedIons[ion] += (ion[0] + str(row + 1) + '; ') * hits
                elif ion in ('x0', 'y0', 'z0'):
                    matchedIons[ion] += (ion[0] + str(rowCount - row) + '; ') * hits

        # internal fragments
        for peptide in internalFragments:
            for ion in peptide:

                # matched and not filtered ion only
                if ion not in internalIons or ion in ('seq', 'pos', 'filter'):
                    continue
                ionData = peptide[ion]
                if ionData[1] == '' or not isAccepted(ion, peptide['filter']):
                    continue

                hits = collectPeaks(ionData)

                # get ion count
                if ion == 'int0':
                    matchedIons['int0'] += hits

        # append data
        data = {}
        data['params'] = [
            ['Tolerance: ', str(self.ctrlData['tolerance']) + ' ' + self.ctrlData['errortype']],
            ['Peaks in peaklist: ', str(len(peaklist))],
            ['Matched peaks: ', str(len(matchedPeaks))],
            ['Missed peaks: ', str(len(peaklist) - len(matchedPeaks))],
            ['Sequence length: ', str(len(sequence))],
            ['Matched fragments (a): ', matchedIons['a0']],
            ['Matched fragments (b): ', matchedIons['b0']],
            ['Matched fragments (c): ', matchedIons['c0']],
            ['Matched fragments (x): ', matchedIons['x0']],
            ['Matched fragments (y): ', matchedIons['y0']],
            ['Matched fragments (z): ', matchedIons['z0']],
            ['Matched fragments (int.): ', str(matchedIons['int0'])],
        ]
        data['errors'] = errorList
        data['hidden'] = {'errortype': errorType}

        return data
    # ----
