Skip to content

Commit bbf140c

Browse files
authored
Merge pull request #12 from Davide-Ruffini/main
Negative Multinomial aggiornata
2 parents cd7b9d0 + ce43f47 commit bbf140c

3 files changed

Lines changed: 347 additions & 2 deletions

File tree

gemact/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,7 @@
5858
'pareto2': 'distributions.Pareto2',
5959
'pareto1': 'distributions.Pareto1',
6060
'uniform': 'distributions.Uniform',
61-
'multinomial': 'distributions.Multinomial'
61+
'multinomial': 'distributions.Multinomial',
62+
'dirichlet multinomial' : 'distributions.Dirichlet_Multinomial',
63+
'negative multinomial' : 'distributions.NegMultinom'
6264
}

gemact/distributions.py

Lines changed: 342 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6261,7 +6261,7 @@ def name():
62616261

62626262
@staticmethod
62636263
def category():
6264-
return {'frequency'}
6264+
return {''}
62656265

62666266

62676267
def cov(self):
@@ -6326,4 +6326,345 @@ def logpmf(self, x):
63266326
return stats.multinomial.logpmf(n=self.n, p=self.p, x=x)
63276327

63286328

6329+
6330+
class Dirichlet_Multinomial(_MultDiscreteDistribution):
    """
    Dirichlet Multinomial distribution.
    Wrapper to scipy dirichlet multinomial distribution (``scipy.stats.dirichlet_multinomial``).
    Refer to :py:class:`~_MultDiscreteDistribution` for additional details.

    :param seed: Seed stored on the instance and forwarded to scipy when the frozen
        distribution is built. When ``None``, a random integer in [1, 1000] is drawn.
    :type seed: int, optional
    :param \\**kwargs:
        See below

    :Keyword Arguments:
        * *alpha* (``int``, ``float``, sequence of them, or ``numpy.ndarray``) --
          Concentration parameters, one per category (strictly positive).
        * *n* (``int``) --
          Number of trials.
    """

    def __init__(self, seed=None, **kwargs):
        # NOTE(review): the initialiser called here is _DiscreteDistribution's although the
        # declared parent is _MultDiscreteDistribution; kept as in the original -- confirm.
        _DiscreteDistribution.__init__(self)
        self.n = kwargs['n']
        self.alpha = kwargs['alpha']
        self.seed = seed

    @property
    def seed(self):
        return self.__seed

    @seed.setter
    def seed(self, value):
        if value is None:
            value = np.random.randint(1, 1001)

        hf.assert_type_value(value, 'seed', logger, (float, int))
        value = int(value)
        self.__seed = value

    @property
    def n(self):
        return self.__n

    @n.setter
    def n(self, value):
        hf.assert_type_value(value, 'n', logger, (float, int), lower_bound=1, lower_close=True)
        value = int(value)
        self.__n = value

    @property
    def alpha(self):
        return self.__alpha

    @alpha.setter
    def alpha(self, value):
        # A bare scalar is documented as valid input, so wrap it before iterating.
        elements = value if np.ndim(value) > 0 else [value]
        for element in elements:
            # Concentration parameters must be strictly positive (scipy rejects
            # non-positive values as well, but only later, at evaluation time).
            hf.assert_type_value(
                element, 'alpha', logger, (float, int),
                lower_bound=0, lower_close=False
            )
        self.__alpha = np.atleast_1d(np.array(value))

    @property
    def _dist(self):
        return stats.dirichlet_multinomial(n=self.n, alpha=self.alpha, seed=self.seed)

    @staticmethod
    def name():
        return 'dirichlet multinomial'

    @staticmethod
    def category():
        return {''}

    def _check_total(self, x):
        """
        Verify that the counts in ``x`` add up to the number of trials.

        :param x: count vector, or array whose last axis indexes the categories.
        :type x: ``numpy.ndarray`` or sequence of ``int``
        :raises ValueError: if any count vector does not sum to ``n``.
        """
        # np.sum over the last axis (instead of the builtin sum) also supports batched input.
        if np.any(np.sum(np.atleast_1d(x), axis=-1) != self.n):
            raise ValueError("n != sum(x), i.e. one is wrong")

    def cov(self):
        """
        Covariance Matrix of a Dirichlet Multinomial Distribution.

        :return: Covariance Matrix.
        :rtype: ``numpy.ndarray``
        """
        return stats.dirichlet_multinomial.cov(n=self.n, alpha=self.alpha)

    def var(self):
        """
        Variances of a Dirichlet Multinomial Distribution.

        :return: Array of Variances (diagonal of the covariance matrix).
        :rtype: ``numpy.ndarray``
        """
        return np.diag(self.cov())

    def pmf(self, x):
        """
        Probability mass function of the Dirichlet Multinomial Distribution.

        :param x: counts where the probability mass function is evaluated; they must sum to ``n``.
        :type x: ``numpy.ndarray`` or sequence of ``int``

        :return: probability mass function.
        :rtype: ``numpy.float64`` or ``numpy.ndarray``
        :raises ValueError: if the counts do not sum to ``n``.
        """
        self._check_total(x)
        return stats.dirichlet_multinomial.pmf(n=self.n, alpha=self.alpha, x=x)

    def logpmf(self, x):
        """
        Natural logarithm of the probability mass function of the Dirichlet Multinomial Distribution.

        :param x: counts where the log probability mass function is evaluated; they must sum to ``n``.
        :type x: ``numpy.ndarray`` or sequence of ``int``
        :return: natural logarithm of the probability mass function
        :rtype: ``numpy.float64`` or ``numpy.ndarray``
        :raises ValueError: if the counts do not sum to ``n``.
        """
        self._check_total(x)
        return stats.dirichlet_multinomial.logpmf(n=self.n, alpha=self.alpha, x=x)

    def mean(self):
        """
        Mean of the Dirichlet Multinomial Distribution.

        :return: mean vector, one entry per category.
        :rtype: ``numpy.ndarray``
        """
        return stats.dirichlet_multinomial.mean(n=self.n, alpha=self.alpha)

    def rvs(self, size=1, random_state=None):
        """
        Random variates generator function.

        Category probabilities are obtained by normalising Gamma(alpha, 1) draws,
        then one multinomial vector with ``n`` trials is drawn per sample.

        :param size: random variates sample size (default is 1).
        :type size: ``int``, optional
        :param random_state: random state for the random number generator.
        :type random_state: ``int``, optional
        :return: random variates, one row per sample.
        :rtype: ``numpy.ndarray``
        :raises ValueError: if ``alpha`` is not one-dimensional.
        """
        random_state = hf.handle_random_state(random_state, logger)
        np.random.seed(random_state)
        hf.assert_type_value(size, 'size', logger, (float, int), lower_bound=1)
        size = int(size)

        if self.alpha.ndim != 1:
            raise ValueError("alpha must be one-dimensional to generate random variates")

        alpha = np.tile(self.alpha, (size, 1))
        gamma_draws = np.random.gamma(shape=alpha, scale=1.0)
        row_totals = np.sum(gamma_draws, axis=1, keepdims=True)

        # Rows where every gamma draw underflowed to zero cannot be normalised;
        # silence the 0/0 warning here and repair those rows just below.
        with np.errstate(divide='ignore', invalid='ignore'):
            prob = gamma_draws / row_totals

        for i in np.where(row_totals[:, 0] == 0)[0]:
            # Fall back to a one-hot probability vector picked proportionally to alpha.
            prob[i, :] = np.random.multinomial(1, alpha[i, :] / np.sum(alpha[i, :]))

        return np.array([np.random.multinomial(self.n, prob[i, :]) for i in range(size)])
6494+
6495+
6496+
class NegMultinom(_MultDiscreteDistribution):
    """
    Negative Multinomial distribution.

    :param x0: Size parameter of the negative multinomial distribution (strictly positive).
    :type x0: ``int`` or ``float``
    :param p: Probability parameters of the negative multinomial distribution, one per
        category; each in [0, 1] and with a sum strictly less than 1.
    :type p: ``numpy.ndarray`` or sequence of ``float``
    :param loc: Location parameter to shift the support (default=0). It is stored on the
        instance; none of the methods defined in this class reads it.
    :type loc: ``int``, optional
    """

    def __init__(self, x0, p, loc=0):
        # NOTE(review): the initialiser called here is _DiscreteDistribution's although the
        # declared parent is _MultDiscreteDistribution; kept as in the original -- confirm.
        _DiscreteDistribution.__init__(self)
        self.x0 = x0
        self.p = p
        self.loc = loc

    @property
    def x0(self):
        return self.__x0

    @x0.setter
    def x0(self, value):
        hf.assert_type_value(value, 'x0', logger, (float, int), lower_bound=0, lower_close=False)
        self.__x0 = value

    @property
    def p(self):
        return self.__p

    @p.setter
    def p(self, value):
        # atleast_1d lets a single probability be passed as a bare scalar.
        value = np.atleast_1d(np.array(value))
        for element in value:
            hf.assert_type_value(
                element, 'p', logger, (float, np.floating),
                lower_bound=0, upper_bound=1, lower_close=True, upper_close=True
            )

        total = np.sum(value)
        if total >= 1:
            raise ValueError("Sum of success probabilities must be less than 1")

        self.__p = value
        self.__p0 = 1 - total  # Failure probability

    @property
    def p0(self):
        """Probability of failure (computed as 1 - sum(p))"""
        return self.__p0

    @staticmethod
    def name():
        return 'negative multinomial'

    @staticmethod
    def category():
        return {''}

    def pmf(self, x):
        """
        Probability mass function.

        PMF formula:
        Γ(x0 + ∑x_i) / Γ(x0) * p0^x0 * ∏(p_i^x_i / x_i!)

        :param x: Counts where PMF is evaluated (last axis indexes the categories).
        :type x: ``numpy.ndarray``
        :return: Probability mass function evaluated at x.
        :rtype: ``numpy.float64`` or ``numpy.ndarray``
        """
        # Evaluated through the log-PMF: the gamma and factorial terms overflow
        # for moderately large counts when computed directly.
        return np.exp(self.logpmf(x))

    def logpmf(self, x):
        """
        Natural logarithm of the probability mass function.

        Computed directly in log space (``gammaln``/``xlogy``) so that large counts
        neither overflow nor underflow.

        :param x: Counts where log-PMF is evaluated (last axis indexes the categories).
        :type x: ``numpy.ndarray``
        :return: Log of probability mass function evaluated at x.
        :rtype: ``numpy.float64`` or ``numpy.ndarray``
        """
        x = np.atleast_1d(np.asarray(x))
        x0 = self.x0

        log_gamma_term = special.gammaln(np.sum(x, axis=-1) + x0) - special.gammaln(x0)
        # xlogy returns 0 for x_i = 0 even when p_i = 0, matching p_i ** 0 == 1;
        # gammaln(x_i + 1) is log(x_i!).
        log_prob_term = x0 * np.log(self.p0) + np.sum(
            special.xlogy(x, self.p) - special.gammaln(x + 1), axis=-1
        )

        return log_gamma_term + log_prob_term

    def mean(self):
        """
        Mean vector of the distribution.

        :return: Mean vector.
        :rtype: ``numpy.ndarray``
        """
        return (self.x0 / self.p0) * self.p

    def var(self):
        """
        Variances of a Negative Multinomial Distribution.

        :return: Array of Variances.
        :rtype: ``numpy.ndarray``
        """
        return (self.x0 / self.p0**2) * self.p**2 + (self.x0 / self.p0) * self.p

    def cov(self):
        """
        Covariance matrix of a Negative Multinomial Distribution.

        Off-diagonal entries are ``x0 * p_i * p_j / p0**2``; the diagonal additionally
        carries ``x0 * p_i / p0``, so it equals :py:meth:`var`.

        :return: Covariance matrix.
        :rtype: ``numpy.ndarray``
        """
        p = self.p
        x0 = self.x0
        p0 = self.p0

        return (x0 / p0**2) * np.outer(p, p) + np.diag((x0 / p0) * p)

    def rvs(self, size=1, random_state=None):
        """
        Random variates generator function.

        Each sample draws the total count from a negative binomial with parameters
        ``(x0, p0)`` and splits it across categories with a multinomial whose
        probabilities are ``p / (1 - p0)``.

        :param size: Number of random variates to generate (default=1).
        :type size: ``int``
        :param random_state: Random state for reproducibility.
        :type random_state: ``int``, optional
        :return: Random variates, one row per sample.
        :rtype: ``numpy.ndarray``
        """
        random_state = hf.handle_random_state(random_state, logger)
        np.random.seed(random_state)
        hf.assert_type_value(size, 'size', logger, (float, int), lower_bound=1)
        size = int(size)

        # Conditional category probabilities given that an event occurred. When all
        # p are zero no event can occur, so the division is skipped altogether.
        event_mass = 1 - self.p0
        cond_p = self.p / event_mass if event_mass > 0 else None

        samples = []
        for _ in range(size):
            total = np.random.negative_binomial(self.x0, self.p0)
            if total > 0:
                counts = np.random.multinomial(total, cond_p)
            else:
                # Integer zeros keep the output dtype identical to the multinomial draws.
                counts = np.zeros(self.p.shape, dtype=int)
            samples.append(counts)

        return np.array(samples)

    def mgf(self, t):
        """
        Moment generating function.

        The expression is only meaningful where ``sum(p * exp(t)) < 1``; no check
        is performed on ``t``.

        :param t: Vector where MGF is evaluated.
        :type t: ``numpy.ndarray``
        :return: Moment generating function evaluated at t.
        :rtype: ``numpy.float64``
        """
        exponent = np.sum(self.p * np.exp(t))
        return (self.p0 / (1 - exponent)) ** self.x0
63296670

gemact/libraries.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,5 @@
1313
import timeit
1414
from twiggy import quick_setup, log
1515
from itertools import groupby
16+
from scipy.special import gammaln
17+
import math as mt

0 commit comments

Comments
 (0)