Wat is een intuïtieve uitleg van de Maximisatietechniek van de verwachting?

Verwachting Maximalisatie (EM) is een soort probabilistische methode om gegevens te classificeren. Corrigeer me als ik het mis heb of het geen classifier is.

Wat is een intuïtieve verklaring van deze EM-techniek? Wat is expectationhier en wat is maximized?


1, Autoriteit 100%

Opmerking: de code achter dit antwoord is te vinden hier .


Stel dat we een aantal gegevens hebben gesampled van twee verschillende groepen, rood en blauw:

Hier kunnen we zien welk gegevenspunt behoort tot de rode of blauwe groep. Dit maakt het gemakkelijk om de parameters die elke groep kenmerken te vinden. Het gemiddelde van de Rode Groep is bijvoorbeeld ongeveer 3, het gemiddelde van de Blue Group is ongeveer 7 (en we konden de exacte middelen vinden als we wilden).

Dit is, in het algemeen gesproken, bekend als Maximale waarschijnlijkheidsschatting . Gezien sommige gegevens berekenen we de waarde van een parameter (of parameters) die het beste uitlegt dat gegevens.

Stel je nu voor dat we niet kunnen zien welke waarde werd bemonsterd uit welke groep. Alles ziet er purper uit naar ons:

Hier hebben we de wetenschap dat er Two -groepen van waarden zijn, maar we weten niet welke groep een bepaalde waarde behoort.

Kunnen we nog steeds de middelen schatten voor de rode groep en de blauwe groep die deze gegevens het best passen?

Ja, vaak kunnen we! Verwachting Maximalisatie geeft ons een manier om het te doen. Het algemene idee achter het algoritme is dit:

  1. Start met een eerste schatting van wat elke parameter zou kunnen zijn.
  2. Bereken de waarschijnlijkheid die elke parameter het gegevenspunt produceert.
  3. Bereken gewichten voor elk gegevenspunt dat aangeeft of het nog rood of meer blauw is op basis van de waarschijnlijkheid dat het wordt geproduceerd door een parameter. Combineer de gewichten met de gegevens (verwachting ).
  4. Bereken een betere schatting voor de parameters met behulp van de aangepaste gegevens (maximalisatie ).
  5. Herhaal de stappen 2 tot 4 totdat de parameterschatting convergeert (het proces stopt met het produceren van een andere schatting).

Deze stappen hebben nog meer uitleg nodig, dus ik loop het hierboven beschreven probleem door.

Voorbeeld: het schatten van gemiddelde en standaardafwijking

Ik gebruik Python in dit voorbeeld, maar de code moet redelijk eenvoudig te begrijpen zijn als u niet bekend bent met deze taal.

Stel dat we twee groepen, rood en blauw hebben, met de waarden gedistribueerd als in de bovenstaande afbeelding. Specifiek bevat elke groep een waarde die is getrokken uit een normale distributie met de volgende parameters:

import numpy as np
from scipy import stats
np.random.seed(110) # for reproducible results
# set parameters
red_mean = 3
red_std = 0.8
blue_mean = 7
blue_std = 2
# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)
both_colours = np.sort(np.concatenate((red, blue))) # for later use...

Hier is een afbeelding van deze rode en blauwe groepen opnieuw (om u te redden van het moeten scrollen):

Wanneer we de kleur van elk punt kunnen zien (d.w.z. welke groep het behoort), is het heel eenvoudig om de gemiddelde en standaarddeviatie voor elke groep te schatten. We passeren gewoon de rode en blauwe waarden aan de ingebouwde functies in Numpy. Bijvoorbeeld:

>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195

Maar wat als we de kleuren van de punten niet kunnen zien? Dat wil zeggen, in plaats van rood of blauw is elk punt paars gekleurd.

Om te proberen de parameters voor gemiddelde en standaarddeviatie voor de rode en blauwe groepen te achterhalen, kunnen we verwachtingsmaximalisatie gebruiken.

Onze eerste stap (stap 1hierboven) is om de parameterwaarden te raden voor het gemiddelde en de standaarddeviatie van elke groep. We hoeven niet intelligent te raden; we kunnen alle nummers kiezen die we leuk vinden:

# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9
# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7

Deze parameterschattingen produceren belcurven die er als volgt uitzien:

Dit zijn slechte schattingen. Beide middelen (de verticale stippellijnen) lijken ver verwijderd van elk soort “midden” voor bijvoorbeeld zinvolle groepen punten. We willen deze schattingen verbeteren.

De volgende stap (stap 2) is het berekenen van de waarschijnlijkheid dat elk gegevenspunt verschijnt onder de huidige parameterschattingen:

likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)

Hier hebben we eenvoudig elk gegevenspunt in de kansdichtheidsfunctiegezet voor een normale verdeling met behulp van onze huidige schattingen van het gemiddelde en de standaarddeviatie voor rood en blauw. Dit vertelt ons bijvoorbeeld dat met onze huidige schattingen het datapunt op 1.761 veelmeer kans heeft om rood (0,189) te zijn dan blauw (0.00003).

Voor elk gegevenspunt kunnen we deze twee waarschijnlijkheidswaarden omzetten in gewichten (stap 3) zodat ze als volgt optellen tot 1:

likelihood_total = likelihood_of_red + likelihood_of_blue
red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total

Met onze huidige schattingen en onze nieuw berekende gewichten kunnen we nu nieuweschattingen berekenen voor het gemiddelde en de standaarddeviatie van de rode en blauwe groepen (stap 4) .

We berekenen tweemaal het gemiddelde en de standaarddeviatie met behulp van allegegevenspunten, maar met de verschillende wegingen: een keer voor de rode gewichten en een keer voor de blauwe gewichten.

Het belangrijkste van intuïtie is dat hoe groter het gewicht van een kleur op een gegevenspunt, hoe meer het gegevenspunt de volgende schattingen voor de parameters van die kleur beïnvloedt. Dit heeft als effect dat de parameters in de goede richting worden “getrokken”.

def estimate_mean(data, weight):
    """
    For each data point, multiply the point by the probability it
    was drawn from the colour's distribution (its "weight").
    Divide by the total weight: essentially, we're finding where 
    the weight is centred among our data points.
    """
    return np.sum(data * weight) / np.sum(weight)
def estimate_std(data, weight, mean):
    """
    For each data point, multiply the point's squared difference
    from a mean value by the probability it was drawn from
    that distribution (its "weight").
    Divide by the total weight: essentially, we're finding where 
    the weight is centred among the values for the difference of
    each data point from the mean.
    This is the estimate of the variance, take the positive square
    root to find the standard deviation.
    """
    variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
    return np.sqrt(variance)
# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)
# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)

We hebben nieuwe schattingen voor de parameters. Om ze weer te verbeteren, kunnen we teruggaan naar stap 2 en het proces herhalen. We doen dit totdat de schattingen convergeren, of nadat een aantal iteraties is uitgevoerd (stap 5).

Voor onze gegevens zien de eerste vijf iteraties van dit proces er als volgt uit (recente iteraties zien er sterker uit):

We zien dat de gemiddelden al convergeren op sommige waarden, en de vormen van de curven (gestuurd door de standaarddeviatie) worden ook stabieler.

Als we 20 herhalingen doorgaan, krijgen we het volgende:

Het EM-proces is geconvergeerd naar de volgende waarden, die erg dicht bij de werkelijke waarden blijken te liggen (waar we de kleuren kunnen zien – geen verborgen variabelen):

         | EM guess | Actual |  Delta
----------+----------+--------+-------
Red mean  |    2.910 |  2.802 |  0.108
Red std   |    0.854 |  0.871 | -0.017
Blue mean |    6.838 |  6.932 | -0.094
Blue std  |    2.227 |  2.195 |  0.032

In de bovenstaande code is het je misschien opgevallen dat de nieuwe schatting voor de standaarddeviatie is berekend met behulp van de schatting van de vorige iteratie voor het gemiddelde. Uiteindelijk maakt het niet uit of we eerst een nieuwe waarde voor het gemiddelde berekenen, omdat we alleen de (gewogen) variantie van waarden rond een centraal punt vinden. We zullen de schattingen voor de parameters nog steeds zien convergeren.


Antwoord 2, autoriteit 26%

EM is een algoritme voor het maximaliseren van een waarschijnlijkheidsfunctie wanneer sommige variabelen in uw model niet worden waargenomen (d.w.z. wanneer u latente variabelen heeft).

Je zou je kunnen afvragen, als we alleen maar proberen een functie te maximaliseren, waarom gebruiken we dan niet gewoon de bestaande machinerie om een ​​functie te maximaliseren. Als je dit probeert te maximaliseren door afgeleiden te nemen en ze op nul te zetten, zul je ontdekken dat in veel gevallen de eerste-orde voorwaarden geen oplossing hebben. Er zit een kip-en-ei-probleem in dat om je modelparameters op te lossen, je de distributie van je niet-geobserveerde gegevens moet kennen; maar de verdeling van uw niet-geobserveerde gegevens is een functie van uw modelparameters.

E-M probeert dit te omzeilen door iteratief een verdeling te raden voor de niet-geobserveerde gegevens, vervolgens de modelparameters te schatten door iets te maximaliseren dat een ondergrens is voor de werkelijke waarschijnlijkheidsfunctie, en te herhalen tot convergentie:

Het EM-algoritme

Begin met een schatting van de waarden van uw modelparameters

E-stap: gebruik voor elk datapunt met ontbrekende waarden uw modelvergelijking om de verdeling van de ontbrekende gegevens op te lossen, gegeven uw huidige schatting van de modelparameters en gegeven de waargenomen gegevens (merk op dat u een verdeling aan het oplossen bent voor elke ontbrekende waarde, niet voor de verwachte waarde). Nu we een verdeling hebben voor elke ontbrekende waarde, kunnen we de verwachtingvan de waarschijnlijkheidsfunctie berekenen met betrekking tot de niet-geobserveerde variabelen. Als onze schatting voor de modelparameter correct was, is deze verwachte waarschijnlijkheid de werkelijke waarschijnlijkheid van onze waargenomen gegevens; als de parameters niet correct waren, is het gewoon een ondergrens.

M-stap: nu we een verwachte waarschijnlijkheidsfunctie hebben zonder niet-geobserveerde variabelen erin, maximaliseert u de functie zoals u zou doen in het volledig waargenomen geval, om een ​​nieuwe schatting van uw modelparameters te krijgen.

Herhaal tot convergentie.


Antwoord 3, autoriteit 20%

Hier is een eenvoudig recept om het algoritme voor verwachtingsmaximalisatie te begrijpen:

1-Lees dit EM-zelfstudiepapierdoor Do en Batzoglou.

2-Misschien heb je vraagtekens in je hoofd, kijk eens naar de uitleg over deze wiskunde stapeluitwisseling pagina.

3-Kijk naar deze code die ik in Python heb geschreven en waarin het voorbeeld wordt uitgelegd in het EM-zelfstudiedocument van item 1:

Waarschuwing:De code kan rommelig/suboptimaal zijn, aangezien ik geen Python-ontwikkelaar ben. Maar het doet zijn werk.

import numpy as np
import math
#### E-M Coin Toss Example as given in the EM tutorial paper by Do and Batzoglou* #### 
def get_mn_log_likelihood(obs,probs):
    """ Return the (log)likelihood of obs, given the probs"""
    # Multinomial Distribution Log PMF
    # ln (pdf)      =             multinomial coeff            *   product of probabilities
    # ln[f(x|n, p)] = [ln(n!) - (ln(x1!)+ln(x2!)+...+ln(xk!))] + [x1*ln(p1)+x2*ln(p2)+...+xk*ln(pk)]     
    multinomial_coeff_denom= 0
    prod_probs = 0
    for x in range(0,len(obs)): # loop through state counts in each observation
        multinomial_coeff_denom = multinomial_coeff_denom + math.log(math.factorial(obs[x]))
        prod_probs = prod_probs + obs[x]*math.log(probs[x])
    multinomial_coeff = math.log(math.factorial(sum(obs))) -  multinomial_coeff_denom
    likelihood = multinomial_coeff + prod_probs
    return likelihood
# 1st:  Coin B, {HTTTHHTHTH}, 5H,5T
# 2nd:  Coin A, {HHHHTHHHHH}, 9H,1T
# 3rd:  Coin A, {HTHHHHHTHH}, 8H,2T
# 4th:  Coin B, {HTHTTTHHTT}, 4H,6T
# 5th:  Coin A, {THHHTHHHTH}, 7H,3T
# so, from MLE: pA(heads) = 0.80 and pB(heads)=0.45
# represent the experiments
head_counts = np.array([5,9,8,4,7])
tail_counts = 10-head_counts
experiments = zip(head_counts,tail_counts)
# initialise the pA(heads) and pB(heads)
pA_heads = np.zeros(100); pA_heads[0] = 0.60
pB_heads = np.zeros(100); pB_heads[0] = 0.50
# E-M begins!
delta = 0.001  
j = 0 # iteration counter
improvement = float('inf')
while (improvement>delta):
    expectation_A = np.zeros((5,2), dtype=float) 
    expectation_B = np.zeros((5,2), dtype=float)
    for i in range(0,len(experiments)):
        e = experiments[i] # i'th experiment
        ll_A = get_mn_log_likelihood(e,np.array([pA_heads[j],1-pA_heads[j]])) # loglikelihood of e given coin A
        ll_B = get_mn_log_likelihood(e,np.array([pB_heads[j],1-pB_heads[j]])) # loglikelihood of e given coin B
        weightA = math.exp(ll_A) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of A proportional to likelihood of A 
        weightB = math.exp(ll_B) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of B proportional to likelihood of B                            
        expectation_A[i] = np.dot(weightA, e) 
        expectation_B[i] = np.dot(weightB, e)
    pA_heads[j+1] = sum(expectation_A)[0] / sum(sum(expectation_A)); 
    pB_heads[j+1] = sum(expectation_B)[0] / sum(sum(expectation_B)); 
    improvement = max( abs(np.array([pA_heads[j+1],pB_heads[j+1]]) - np.array([pA_heads[j],pB_heads[j]]) ))
    j = j+1

Antwoord 4, autoriteit 13%

Technisch gezien is de term “EM” een beetje ondergespecificeerd, maar ik neem aan dat je verwijst naar de Gaussian Mixture Modeling-clusteranalysetechniek, dat een exemplaaris van het algemene EM-principe.

Eigenlijk is EM-clusteranalyse geen classificatie. Ik weet dat sommige mensen clustering beschouwen als “classificatie zonder toezicht”, maar clusteranalyse is in feite iets heel anders.

Het belangrijkste verschil, en het grote misverstand over classificatie dat mensen altijd hebben met clusteranalyse, is dat: in clusteranalyse geen “juiste oplossing” is. Het is een kennis ontdekkingmethode, het is eigenlijk bedoeld om iets nieuwste vinden! Dit maakt de evaluatie erg lastig. Het wordt vaak geëvalueerd met behulp van een bekende classificatie als referentie, maar dat is niet altijd geschikt: de classificatie die u heeft, kan al dan niet overeenkomen met wat er in de gegevens staat.

Laat me je een voorbeeld geven: je hebt een grote dataset van klanten, inclusief geslachtsgegevens. Een methode die deze dataset opsplitst in “mannelijk” en “vrouwelijk” is optimaal wanneer je deze vergelijkt met de bestaande klassen. In een “voorspellende” manier van denken is dit goed, want voor nieuwe gebruikers zou je nu hun geslacht kunnen voorspellen. In een “knowledge discovery” manier van denken is dit eigenlijk slecht, omdat je een nieuwe structuurin de data wilde ontdekken. Een methode die b.v. het splitsen van de gegevens in ouderen en kinderen zou echter zo slecht mogelijkscoren met betrekking tot de mannelijke/vrouwelijke klasse. Dat zou echter een uitstekend clusterresultaat zijn (als de leeftijd niet was opgegeven).

Nu terug naar EM. In wezen gaat het ervan uit dat uw gegevens zijn samengesteld uit meerdere multivariate normale verdelingen (merk op dat dit een zeersterke aanname is, vooral wanneer u het aantal clusters vastlegt!). Vervolgens probeert het hiervoor een lokaal optimaal model te vinden door afwisselend het model en de objecttoewijzing aan het model te verbeteren.

Kies voor de beste resultaten in een classificatiecontext het aantal clusters groterdan het aantal klassen, of pas de clustering zelfs alleen toe op enkeleklassen (om erachter te komen of er enige structuur is in de klas!).

Stel dat je een classifier wilt trainen om ‘auto’s’, ‘fietsen’ en ‘vrachtwagens’ van elkaar te onderscheiden. Het heeft weinig zin om aan te nemen dat de gegevens uit precies 3 normale verdelingen bestaan. U mag er echter vanuit gaan dat er meer dan één type auto is(en vrachtwagens en fietsen). Dus in plaats van een classifier voor deze drie klassen te trainen, cluster je auto’s, vrachtwagens en fietsen in elk 10 clusters (of misschien 10 auto’s, 3 vrachtwagens en 3 fietsen, wat dan ook), en train je een classifier om deze 30 klassen uit elkaar te houden, en dan voeg het klasseresultaat weer samen met de oorspronkelijke klassen. U kunt ook ontdekken dat er één cluster is dat bijzonder moeilijk te classificeren is, bijvoorbeeld Trikes. Het zijn een beetje auto’s, en een beetje fietsen. Of bestelwagens, die meer op grote auto’s lijken dan op vrachtwagens.


Antwoord 5

Het geaccepteerde antwoord verwijst naar het Chuong EM Paper, dat een behoorlijke baan uitleggen EM. Er is ook een youtube-videowaarin de paper in meer detail wordt uitgelegd.

Om samen te vatten, hier is het scenario:

1st:  {H,T,T,T,H,H,T,H,T,H} 5 Heads, 5 Tails; Did coin A or B generate me?
2nd:  {H,H,H,H,T,H,H,H,H,H} 9 Heads, 1 Tails
3rd:  {H,T,H,H,H,H,H,T,H,H} 8 Heads, 2 Tails
4th:  {H,T,H,T,T,T,H,H,T,T} 4 Heads, 6 Tails
5th:  {T,H,H,H,T,H,H,H,T,H} 7 Heads, 3 Tails
Two possible coins, A & B are used to generate these distributions.
A & B have an unknown parameter: their bias towards heads.
We don't know the biases, but we can simply start with a guess: A=60% heads, B=50% heads.

In het geval van de vraag van de eerste proef, zouden we intuïtief denken dat B deze heeft gegenereerd, aangezien het aantal koppen goed overeenkomt met de vooringenomenheid van B… maar die waarde was slechts een gok, dus we weten het niet zeker.

Met dat in gedachten, zie ik de EM-oplossing graag als volgt:

  • Bij elke poging tot flips kan ‘stemmen’ op welke munt hij het leukst vindt
    • Dit is gebaseerd op hoe goed elke munt in zijn distributie past
    • OF, vanuit het oogpunt van de munt is er hoge verwachtingom deze proef te zien in vergelijking met de andere munt (gebaseerd op logwaarschijnlijkheid).
  • Afhankelijk van hoeveel elke proef elke munt leuk vindt, kan het de schatting van de parameter (bias) van die munt bijwerken.
    • Hoe meer een proef een munt leuk vindt, des te beter het wordt om de vooringenomenheid van de munt bij te werken om die van hemzelf weer te geven!
    • In wezen worden de vooroordelen van de munt bijgewerkt door deze gewogen updates voor alle proeven te combineren, een proces genaamd (maximalisatie), dat verwijst naar het proberen om de beste schattingen te krijgen voor de vooringenomenheid van elke munt gegeven een reeks proeven .

Dit kan een te grote vereenvoudiging zijn (of zelfs fundamenteel verkeerd op sommige niveaus), maar ik hoop dat dit helpt op een intuïtief niveau!


Antwoord 6

EM wordt gebruikt om de kans op een model Q met latente variabelen Z te maximaliseren.

Het is een iteratieve optimalisatie.

theta <- initial guess for hidden parameters
while not converged:
    #e-step
    Q(theta'|theta) = E[log L(theta|Z)]
    #m-step
    theta <- argmax_theta' Q(theta'|theta)

e-stap:
gegeven huidige schatting van Z bereken de verwachte loglikelihood-functie

m-stap:
vind theta die deze Q maximaliseert

AGM-voorbeeld:

e-step: schat de labeltoewijzingen voor elk datapunt op basis van de huidige gmm-parameterschatting

m-step: maximaliseer een nieuwe theta met de nieuwe labeltoewijzingen

K-means is ook een EM-algoritme en er zijn veel verklarende animaties op K-means.


Antwoord 7

Met hetzelfde artikel van Do en Batzoglou dat in het antwoord van Zhubarb wordt aangehaald, heb ik EM voor dat probleem geïmplementeerd in Java. De opmerkingen bij zijn antwoord laten zien dat het algoritme vastloopt op een lokaal optimum, wat ook gebeurt bij mijn implementatie als de parameters thetaA en thetaB hetzelfde zijn.

Hieronder staat de standaarduitvoer van mijn code, die de convergentie van de parameters laat zien.

thetaA = 0.71301, thetaB = 0.58134
thetaA = 0.74529, thetaB = 0.56926
thetaA = 0.76810, thetaB = 0.54954
thetaA = 0.78316, thetaB = 0.53462
thetaA = 0.79106, thetaB = 0.52628
thetaA = 0.79453, thetaB = 0.52239
thetaA = 0.79593, thetaB = 0.52073
thetaA = 0.79647, thetaB = 0.52005
thetaA = 0.79667, thetaB = 0.51977
thetaA = 0.79674, thetaB = 0.51966
thetaA = 0.79677, thetaB = 0.51961
thetaA = 0.79678, thetaB = 0.51960
thetaA = 0.79679, thetaB = 0.51959
Final result:
thetaA = 0.79678, thetaB = 0.51960

Hieronder staat mijn Java-implementatie van EM om het probleem op te lossen (Do en Batzoglou, 2008). Het kernonderdeel van de implementatie is de lus om EM uit te voeren totdat de parameters convergeren.

private Parameters _parameters;
public Parameters run()
{
    while (true)
    {
        expectation();
        Parameters estimatedParameters = maximization();
        if (_parameters.converged(estimatedParameters)) {
            break;
        }
        _parameters = estimatedParameters;
    }
    return _parameters;
}

Hieronder staat de volledige code.

import java.util.*;
/*****************************************************************************
This class encapsulates the parameters of the problem. For this problem posed
in the article by (Do and Batzoglou, 2008), the parameters are thetaA and
thetaB, the probability of a coin coming up heads for the two coins A and B,
respectively.
*****************************************************************************/
class Parameters
{
    double _thetaA = 0.0; // Probability of heads for coin A.
    double _thetaB = 0.0; // Probability of heads for coin B.
    double _delta = 0.00001;
    public Parameters(double thetaA, double thetaB)
    {
        _thetaA = thetaA;
        _thetaB = thetaB;
    }
    /*************************************************************************
    Returns true if this parameter is close enough to another parameter
    (typically the estimated parameter coming from the maximization step).
    *************************************************************************/
    public boolean converged(Parameters other)
    {
        if (Math.abs(_thetaA - other._thetaA) < _delta &&
            Math.abs(_thetaB - other._thetaB) < _delta)
        {
            return true;
        }
        return false;
    }
    public double getThetaA()
    {
        return _thetaA;
    }
    public double getThetaB()
    {
        return _thetaB;
    }
    public String toString()
    {
        return String.format("thetaA = %.5f, thetaB = %.5f", _thetaA, _thetaB);
    }
}
/*****************************************************************************
This class encapsulates an observation, that is the number of heads
and tails in a trial. The observation can be either (1) one of the
experimental observations, or (2) an estimated observation resulting from
the expectation step.
*****************************************************************************/
class Observation
{
    double _numHeads = 0;
    double _numTails = 0;
    public Observation(String s)
    {
        for (int i = 0; i < s.length(); i++)
        {
            char c = s.charAt(i);
            if (c == 'H')
            {
                _numHeads++;
            }
            else if (c == 'T')
            {
                _numTails++;
            }
            else
            {
                throw new RuntimeException("Unknown character: " + c);
            }
        }
    }
    public Observation(double numHeads, double numTails)
    {
        _numHeads = numHeads;
        _numTails = numTails;
    }
    public double getNumHeads()
    {
        return _numHeads;
    }
    public double getNumTails()
    {
        return _numTails;
    }
    public String toString()
    {
        return String.format("heads: %.1f, tails: %.1f", _numHeads, _numTails);
    }
}
/*****************************************************************************
This class runs expectation-maximization for the problem posed by the article
from (Do and Batzoglou, 2008).
*****************************************************************************/
public class EM
{
    // Current estimated parameters.
    private Parameters _parameters;
    // Observations from the trials. These observations are set once.
    private final List<Observation> _observations;
    // Estimated observations per coin. These observations are the output
    // of the expectation step.
    private List<Observation> _expectedObservationsForCoinA;
    private List<Observation> _expectedObservationsForCoinB;
    private static java.io.PrintStream o = System.out;
    /*************************************************************************
    Principal constructor.
    @param observations The observations from the trial.
    @param parameters The initial guessed parameters.
    *************************************************************************/
    public EM(List<Observation> observations, Parameters parameters)
    {
        _observations = observations;
        _parameters = parameters;
    }
    /*************************************************************************
    Run EM until parameters converge.
    *************************************************************************/
    public Parameters run()
    {
        while (true)
        {
            expectation();
            Parameters estimatedParameters = maximization();
            o.printf("%s\n", estimatedParameters);
            if (_parameters.converged(estimatedParameters)) {
                break;
            }
            _parameters = estimatedParameters;
        }
        return _parameters;
    }
    /*************************************************************************
    Given the observations and current estimated parameters, compute new
    estimated completions (distribution over the classes) and observations.
    *************************************************************************/
    private void expectation()
    {
        _expectedObservationsForCoinA = new ArrayList<Observation>();
        _expectedObservationsForCoinB = new ArrayList<Observation>();
        for (Observation observation : _observations)
        {
            int numHeads = (int)observation.getNumHeads();
            int numTails = (int)observation.getNumTails();
            double probabilityOfObservationForCoinA=
                binomialProbability(10, numHeads, _parameters.getThetaA());
            double probabilityOfObservationForCoinB=
                binomialProbability(10, numHeads, _parameters.getThetaB());
            double normalizer = probabilityOfObservationForCoinA +
                                probabilityOfObservationForCoinB;
            // Compute the completions for coin A and B (i.e. the probability
            // distribution of the two classes, summed to 1.0).
            double completionCoinA = probabilityOfObservationForCoinA /
                                     normalizer;
            double completionCoinB = probabilityOfObservationForCoinB /
                                     normalizer;
            // Compute new expected observations for the two coins.
            Observation expectedObservationForCoinA =
                new Observation(numHeads * completionCoinA,
                                numTails * completionCoinA);
            Observation expectedObservationForCoinB =
                new Observation(numHeads * completionCoinB,
                                numTails * completionCoinB);
            _expectedObservationsForCoinA.add(expectedObservationForCoinA);
            _expectedObservationsForCoinB.add(expectedObservationForCoinB);
        }
    }
    /*************************************************************************
    Given new estimated observations, compute new estimated parameters.
    *************************************************************************/
    private Parameters maximization()
    {
        double sumCoinAHeads = 0.0;
        double sumCoinATails = 0.0;
        double sumCoinBHeads = 0.0;
        double sumCoinBTails = 0.0;
        for (Observation observation : _expectedObservationsForCoinA)
        {
            sumCoinAHeads += observation.getNumHeads();
            sumCoinATails += observation.getNumTails();
        }
        for (Observation observation : _expectedObservationsForCoinB)
        {
            sumCoinBHeads += observation.getNumHeads();
            sumCoinBTails += observation.getNumTails();
        }
        return new Parameters(sumCoinAHeads / (sumCoinAHeads + sumCoinATails),
                              sumCoinBHeads / (sumCoinBHeads + sumCoinBTails));
        //o.printf("parameters: %s\n", _parameters);
    }
    /*************************************************************************
    Since the coin-toss experiment posed in this article is a Bernoulli trial,
    use a binomial probability Pr(X=k; n,p) = (n choose k) * p^k * (1-p)^(n-k).
    *************************************************************************/
    private static double binomialProbability(int n, int k, double p)
    {
        double q = 1.0 - p;
        return nChooseK(n, k) * Math.pow(p, k) * Math.pow(q, n-k);
    }
    private static long nChooseK(int n, int k)
    {
        long numerator = 1;
        for (int i = 0; i < k; i++)
        {
            numerator = numerator * n;
            n--;
        }
        long denominator = factorial(k);
        return (long)(numerator / denominator);
    }
    private static long factorial(int n)
    {
        long result = 1;
        for (; n >0; n--)
        {
            result = result * n;
        }
        return result;
    }
    /*************************************************************************
    Entry point into the program.
    *************************************************************************/
    public static void main(String argv[])
    {
        // Create the observations and initial parameter guess
        // from the (Do and Batzoglou, 2008) article.
        List<Observation> observations = new ArrayList<Observation>();
        observations.add(new Observation("HTTTHHTHTH"));
        observations.add(new Observation("HHHHTHHHHH"));
        observations.add(new Observation("HTHHHHHTHH"));
        observations.add(new Observation("HTHTTTHHTT"));
        observations.add(new Observation("THHHTHHHTH"));
        Parameters initialParameters = new Parameters(0.6, 0.5);
        EM em = new EM(observations, initialParameters);
        Parameters finalParameters = em.run();
        o.printf("Final result:\n%s\n", finalParameters);
    }
}

Other episodes