9 lines
295 B
Python
9 lines
295 B
Python
import numpy as np
|
|
|
|
|
|
def categorical_sample(prob_n, np_random: np.random.Generator):
    """Draw one class index from a categorical distribution.

    ``prob_n`` is treated as a single flattened vector of class probabilities
    (``np.cumsum`` flattens multi-dimensional input). One uniform draw ``u``
    is taken from ``np_random.random()``. The returned index is the first
    class whose cumulative probability exceeds ``u``.

    Args:
        prob_n: Array-like of non-negative class probabilities, expected to
            sum to (approximately) 1.
        np_random: NumPy random generator used for the single uniform draw.

    Returns:
        The sampled class index, as a NumPy integer.

    Raises:
        ValueError: If ``prob_n`` is empty or contains no positive probability.
    """
    prob_n = np.asarray(prob_n)
    csprob_n = np.cumsum(prob_n)
    u = np_random.random()

    exceeds = csprob_n > u
    if exceeds.any():
        # First index whose cumulative mass passes u; same result as before.
        return np.argmax(exceeds)

    # u landed at or beyond the total mass. This typically happens when the
    # probabilities sum to slightly less than 1 because of float rounding.
    # argmax of an all-False array would silently return 0, which could be a
    # zero-probability class. Use the last class that actually has mass.
    positive = np.flatnonzero(prob_n > 0)
    if positive.size == 0:
        raise ValueError(
            "categorical_sample requires at least one positive probability"
        )
    return positive[-1]