Skip to content

Commit a944a89

Browse files
committed
change GENETIC to Genetic
1 parent 4dccfd5 commit a944a89

File tree

3 files changed

+49
-28
lines changed

3 files changed

+49
-28
lines changed

doc/guide/guide-control.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -351,8 +351,8 @@ have passed without improvement (stagnation criterion).
351351
Each generation represents a full evaluation of the population, making the method inherently parallelizable
352352
and effective in high-dimensional control landscapes.
353353
354-
In QuTiP, the GA optimization is implemented via the ``_GENETIC`` class, and can be invoked using the
355-
standard ``optimize_pulses`` interface by setting the algorithm to ``"GENETIC"``.
354+
In QuTiP, the GA optimization is implemented via the ``_Genetic`` class, and can be invoked using the
355+
standard ``optimize_pulses`` interface by setting the algorithm to ``"Genetic"``.
356356
357357
Optimal Quantum Control in QuTiP (Genetic Algorithm)
358358
====================================================
@@ -388,7 +388,7 @@ Defining a control problem in QuTiP using the Genetic Algorithm follows the same
388388
389389
# set genetic algorithm hyperparameters
390390
algorithm_kwargs = {
391-
"alg": "GENETIC",
391+
"alg": "Genetic",
392392
"population_size": 50,
393393
"generations": 100,
394394
"mutation_rate": 0.2,

src/qutip_qoc/_genetic.py

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
import time
44
from qutip_qoc.result import Result
55

6-
class _GENETIC:
6+
7+
class _Genetic:
78
"""
89
Genetic Algorithm-based optimizer for quantum control problems.
910
1011
This class implements a global optimization routine using a Genetic Algorithm
11-
to optimize control pulses that drive a quantum system from an initial state
12-
to a target state (or unitary).
12+
to optimize control pulses that drive a quantum system from an initial state
13+
to a target state (or unitary).
14+
15+
Based on the code from Jonathan Brown
1316
"""
1417

1518
def __init__(
@@ -26,7 +29,9 @@ def __init__(
2629
):
2730
self._objective = objectives[0]
2831
self._Hd = self._objective.H[0]
29-
self._Hc_lst = [H[0] if isinstance(H, list) else H for H in self._objective.H[1:]]
32+
self._Hc_lst = [
33+
H[0] if isinstance(H, list) else H for H in self._objective.H[1:]
34+
]
3035
self._initial = self._objective.initial
3136
self._target = self._objective.target
3237
self._norm_fac = 1 / self._target.norm()
@@ -41,14 +46,17 @@ def __init__(
4146
self.generations = alg_kwargs.get("generations", 100)
4247
self.mutation_rate = alg_kwargs.get("mutation_rate", 0.3)
4348
self.fid_err_targ = alg_kwargs.get("fid_err_targ", 1e-4)
44-
self._stagnation_patience = 20 # Internally fixed
49+
self._stagnation_patience = 50 # Internally fixed
4550

4651
self._integrator_kwargs = integrator_kwargs
4752
self._fid_type = alg_kwargs.get("fid_type", "PSU")
4853

4954
self._generator = self._prepare_generator()
50-
self._solver = qt.MESolver(H=self._generator, options=self._integrator_kwargs) \
51-
if self._Hd.issuper else qt.SESolver(H=self._generator, options=self._integrator_kwargs)
55+
self._solver = (
56+
qt.MESolver(H=self._generator, options=self._integrator_kwargs)
57+
if self._Hd.issuper
58+
else qt.SESolver(H=self._generator, options=self._integrator_kwargs)
59+
)
5260

5361
self._result = Result(
5462
objectives=[self._objective],
@@ -64,10 +72,18 @@ def __init__(
6472
self._result._final_states = []
6573

6674
def _prepare_generator(self):
67-
args = {f"p{i+1}_{j}": 0.0 for i in range(self.N_controls) for j in range(self.N_steps)}
75+
args = {
76+
f"p{i+1}_{j}": 0.0
77+
for i in range(self.N_controls)
78+
for j in range(self.N_steps)
79+
}
6880

6981
def make_coeff(i, j):
70-
return lambda t, args: args[f"p{i+1}_{j}"] if int(t / (self._evo_time / self.N_steps)) == j else 0
82+
return lambda t, args: (
83+
args[f"p{i+1}_{j}"]
84+
if int(t / (self._evo_time / self.N_steps)) == j
85+
else 0
86+
)
7187

7288
H_qev = [self._Hd]
7389
for i, Hc in enumerate(self._Hc_lst):
@@ -77,7 +93,11 @@ def make_coeff(i, j):
7793
return qt.QobjEvo(H_qev, args=args)
7894

7995
def _infid(self, params):
80-
args = {f"p{i+1}_{j}": params[i * self.N_steps + j] for i in range(self.N_controls) for j in range(self.N_steps)}
96+
args = {
97+
f"p{i+1}_{j}": params[i * self.N_steps + j]
98+
for i in range(self.N_controls)
99+
for j in range(self.N_steps)
100+
}
81101
result = self._solver.run(self._initial, [0.0, self._evo_time], args=args)
82102
final_state = result.final_state
83103
self._result._final_states.append(final_state)
@@ -87,7 +107,9 @@ def _infid(self, params):
87107
fid = 0.5 * np.real((diff.dag() * diff).tr())
88108
else:
89109
overlap = self._norm_fac * self._target.overlap(final_state)
90-
fid = 1 - np.abs(overlap) if self._fid_type == "PSU" else 1 - np.real(overlap)
110+
fid = (
111+
1 - np.abs(overlap) if self._fid_type == "PSU" else 1 - np.real(overlap)
112+
)
91113

92114
return fid
93115

@@ -101,7 +123,7 @@ def initial_population(self):
101123
return np.random.uniform(-1, 1, (self.N_pop, self.N_var))
102124

103125
def darwin(self, population, fitness):
104-
indices = np.argsort(-fitness)[:self.N_pop // 2]
126+
indices = np.argsort(-fitness)[: self.N_pop // 2]
105127
return population[indices], fitness[indices]
106128

107129
def pairing(self, survivors, survivor_fitness):
@@ -130,7 +152,9 @@ def build_next_gen(self, survivors, offspring):
130152
return np.vstack((survivors, offspring))
131153

132154
def mutate(self, population):
133-
n_mut = int((population.shape[0] - 1) * population.shape[1] * self.mutation_rate)
155+
n_mut = int(
156+
(population.shape[0] - 1) * population.shape[1] * self.mutation_rate
157+
)
134158
row = np.random.randint(1, population.shape[0], size=n_mut)
135159
col = np.random.randint(0, population.shape[1], size=n_mut)
136160
population[row, col] += np.random.normal(0, 0.3, size=n_mut)
@@ -183,13 +207,15 @@ def optimize(self):
183207

184208
self._result.message = (
185209
f"Stopped early: reached infidelity target {self.fid_err_targ}"
186-
if -best_fit <= self.fid_err_targ else
187-
f"Stopped due to stagnation after {self._stagnation_patience} generations"
188-
if no_improvement_counter >= self._stagnation_patience else
189-
"Optimization completed successfully"
210+
if -best_fit <= self.fid_err_targ
211+
else (
212+
f"Stopped due to stagnation after {self._stagnation_patience} generations"
213+
if no_improvement_counter >= self._stagnation_patience
214+
else "Optimization completed successfully"
215+
)
190216
)
191217
return self._result
192-
218+
193219
def result(self):
194220
self._result.start_local_time = time.strftime(
195221
"%Y-%m-%d %H:%M:%S", time.localtime(self._result.start_local_time)

tests/test_result.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,6 @@ def sin_z_jax(t, r, **kwargs):
200200

201201
# --------------------------- Genetic ---------------------------
202202

203-
# TODO: this is the input for optimiz_pulses() function
204-
# you can use this routine to test your implementation
205203

206204
# state to state transfer
207205
init = qt.basis(2, 0)
@@ -219,16 +217,15 @@ def sin_z_jax(t, r, **kwargs):
219217
control_parameters={
220218
"p": {"bounds": [(-13, 13)]},
221219
},
222-
tlist=np.linspace(0, 10, 100), # TODO: derive single step duration and max evo time / max num steps from this
220+
tlist=np.linspace(0, 10, 100),
223221
algorithm_kwargs={
224222
"fid_err_targ": 0.01,
225-
"alg": "GENETIC",
223+
"alg": "Genetic",
226224
"max_iter": 100,
227225
},
228226
optimizer_kwargs={},
229227
)
230228

231-
# TODO: no big difference for unitary evolution
232229

233230
initial = qt.qeye(2) # Identity
234231
target = qt.gates.hadamard_transform()
@@ -237,8 +234,6 @@ def sin_z_jax(t, r, **kwargs):
237234
objectives=[Objective(initial, H, target)],
238235
)
239236

240-
241-
242237
@pytest.fixture(
243238
params=[
244239
pytest.param(state2state_grape, id="State to state (GRAPE)"),

0 commit comments

Comments
 (0)