Source code for krotov.info_hooks

"""Routines that can be passed as `info_hook` to :func:`.optimize_pulses`"""
import sys
import time

import numpy as np
from uniseg.graphemecluster import grapheme_clusters


# Public API: the names exported by ``from <this module> import *``.
# Helpers prefixed with an underscore (e.g. ``_qobj_nbytes``) are internal
# and intentionally not listed.
__all__ = ['chain', 'print_debug_information', 'print_table']


def _qobj_nbytes(qobj):
    """Return an estimate for the number of bytes of memory required to store a
    Qobj.

    The estimate is the sum of:

    * the buffer sizes (``nbytes``) of ``qobj.data.data``,
      ``qobj.data.indptr`` and ``qobj.data.indices``,
    * the shallow size of the `qobj` object itself, and
    * the shallow size of each value in the instance ``__dict__`` of `qobj`.

    Note:
        :func:`sys.getsizeof` is shallow: containers held in attributes are
        not traversed. The result is therefore only a rough estimate.
    """
    return (
        qobj.data.data.nbytes
        + qobj.data.indptr.nbytes
        + qobj.data.indices.nbytes
        + sys.getsizeof(qobj)
        # Size the attribute *values*. Iterating ``__dict__`` directly would
        # only yield the attribute names (the dict keys).
        + sum(sys.getsizeof(value) for value in qobj.__dict__.values())
    )


def chain(*hooks):
    """Chain multiple `info_hook` or `modify_params_after_iter` callables
    together.

    Example:

        >>> def print_fidelity(**kwargs):
        ...     F_re = np.average(np.array(kwargs['tau_vals']).real)
        ...     print(" F = %f" % F_re)
        >>> info_hook = chain(print_debug_information, print_fidelity)

    Note:
        All functions connected via :func:`chain` receive the same keyword
        arguments, including the same `shared_data` argument. They can use it
        to communicate down the chain.

    Args:
        hooks: Callables that accept keyword arguments only.

    Returns:
        callable: A function taking keyword arguments. It calls every hook in
        order with those arguments and collects the results that are not
        None. It returns None if there are no such results, the single result
        if there is exactly one, and a tuple of the results (in hook order)
        otherwise.
    """

    def info_hook(**kwargs):
        # Every hook is called (in order), even if earlier ones returned None.
        outputs = (hook(**kwargs) for hook in hooks)
        collected = [out for out in outputs if out is not None]
        if not collected:
            return None
        if len(collected) == 1:
            return collected[0]
        return tuple(collected)

    return info_hook
def _grapheme_len(text, fail_with_zero=False):
    """Number of graphemes in `text`.

    This is the length of the `text` when printed::

        >>> s = 'A\\u0302'
        >>> len(s)
        2
        >>> _grapheme_len(s)
        1

    Args:
        text (str): The string to measure.
        fail_with_zero (bool): If True, return 0 if `text` is not a string,
            instead of raising a :exc:`TypeError`.

    Returns:
        int: The number of grapheme clusters in `text`.

    Raises:
        TypeError: If `text` is not a string and `fail_with_zero` is False.
    """
    try:
        # Count without building a list. Consumption of the cluster iterator
        # stays inside the ``try``, since the TypeError for a non-string may
        # only surface once iteration starts.
        return sum(1 for _ in grapheme_clusters(text))
    except TypeError:
        if fail_with_zero:
            return 0
        raise


def _rjust(text, width, fillchar=' '):
    """Right-justify text for a total of `width` graphemes.

    Unlike :meth:`str.rjust`, the `width` is counted in graphemes, not in
    code points::

        >>> s = 'A\\u0302'
        >>> s.rjust(2) == s
        True
        >>> _rjust(s, 2) == ' ' + s
        True

    If `text` is already at least `width` graphemes long, it is returned
    unchanged (a non-positive repeat count yields an empty padding).
    """
    padding = width - _grapheme_len(text)
    return fillchar * padding + text


class _GalHdrUnicodeSubscript:
    """Label for "show_g_a_int_per_pulse" columns, with unicode subscript.

    Example::

        >>> print(_GalHdrUnicodeSubscript().format(l=123))
        ∫gₐ(ϵ₁₂₃)dt

    NOTE(review): presumably this exposes a ``format`` method so that it can
    stand in for a ``str`` header template used via ``.format(l=...)`` —
    confirm with callers.
    """

    # Translation table from ASCII digits and the minus sign to their Unicode
    # subscript forms (U+2080..U+2089 and U+208B SUBSCRIPT MINUS).
    _SUBSCRIPT = str.maketrans(
        '0123456789-',
        '\u2080\u2081\u2082\u2083\u2084\u2085\u2086\u2087\u2088\u2089\u208b',
    )

    def format(self, l):
        """Return a string label for `l` (presumably the pulse index).

        `l` is converted with :func:`int`, and its decimal representation is
        rendered with Unicode subscript characters.
        """
        subscript = str(int(l)).translate(self._SUBSCRIPT)
        return "∫gₐ(ϵ" + subscript + ")dt"


def _pulse_range(pulse):
    """Return the range information about the given pulse value array.

    Example::

        >>> _pulse_range(np.array([-1, 1, 5]))
        '[-1.00, 5.00]'
        >>> _pulse_range(np.array([-1, 1+5j, 5]))
        '[(r:-1.00, i:0.00), (r:5.00, i:5.00)]'

    For complex pulses, the minima and maxima of the real and imaginary parts
    are taken independently. The two reported pairs therefore need not
    correspond to values that actually occur in the pulse.

    Args:
        pulse: Array-like of (real or complex) pulse values. Sequences are
            converted with :func:`numpy.asanyarray`, which keeps ndarray
            subclasses intact.

    Returns:
        str: ``"[min, max]"`` for real pulses, or
        ``"[(r:min_re, i:min_im), (r:max_re, i:max_im)]"`` for complex ones.
    """
    # Allows plain sequences (e.g. a list of complex numbers, which has no
    # ``.real`` attribute) in addition to numpy arrays.
    pulse = np.asanyarray(pulse)
    if not np.iscomplexobj(pulse):
        return "[%.2f, %.2f]" % (np.min(pulse), np.max(pulse))
    r0, r1 = np.min(pulse.real), np.max(pulse.real)
    i0, i1 = np.min(pulse.imag), np.max(pulse.imag)
    return "[(r:%.2f, i:%.2f), (r:%.2f, i:%.2f)]" % (r0, i0, r1, i1)