# Source code for qmla.model_for_comparison

import numpy as np
import scipy as sp
import os
import time
import itertools
import copy

import qinfer as qi
import redis
import pickle

import qmla.redis_settings
import qmla.logging
import qmla.get_exploration_strategy
import qmla.model_building_utilities

# Cap the pickle protocol at 4, presumably so data serialised with
# protocol=pickle.HIGHEST_PROTOCOL stays readable by workers on older
# Python versions (protocol 5 requires 3.8+).
# NOTE(review): this monkey-patches the pickle module globally for the whole
# process — confirm this is intended and that no caller relies on protocol 5.
pickle.HIGHEST_PROTOCOL = 4

__all__ = ["ModelInstanceForComparison"]


class ModelInstanceForComparison:
    """
    Model instances used for Bayes factor comparisons.

    When Bayes factors are calculated remotely (ie on RQ workers),
    they require infrastructure to do calculations, e.g. QInfer SMCUpdater instances.
    This class captures the minimum required to enable these calculations.
    After learning, important data from :class:`~qmla.ModelInstanceForLearning`
    is stored on the redis database.
    This class unpickles the useful information and generates new instances
    of the updater etc. to use in the comparison calculations.
    If run locally, `qmla_core_info_database` and `learned_model_info`
    can be passed directly to this class, to save unpickling data
    from the redis database.

    :param int model_id: ID of the model to study
    :param qid: ID of the QMLA instance
    :param dict qmla_core_info_database: essential details about the QMLA
        instance needed to learn/compare models.
        If None, this is retrieved instead from the redis database.
    :param dict learned_model_info: result of learning, generated by
        :meth:`~qmla.ModelInstanceForLearning.learned_info_dict`.
    :param str host_name: name of host server on which redis database exists.
    :param int port_number: port number unique to this QMLA instance on redis database
    :param str log_file: path of QMLA instance's log file.
    """

    def __init__(
        self,
        model_id,
        qid,
        opponent,
        qmla_core_info_database=None,
        learned_model_info=None,
        host_name="localhost",
        port_number=6379,
        log_file="QMD_log.log",
    ):
        self.log_file = log_file
        self.qmla_id = qid
        self.model_id = model_id
        self.opponent = int(opponent)

        # Get essential data: either directly from the passed-in dict, or by
        # unpickling it from the redis database for this QMLA instance.
        if qmla_core_info_database is None:
            redis_databases = qmla.redis_settings.get_redis_databases_by_qmla_id(
                host_name, port_number, qid
            )
            qmla_core_info_database = redis_databases["qmla_core_info_database"]
            qmla_core_info_dict = pickle.loads(
                qmla_core_info_database.get("qmla_settings")
            )
            self.probes_system = pickle.loads(qmla_core_info_database["probes_system"])
            self.probes_simulator = pickle.loads(
                qmla_core_info_database["probes_simulator"]
            )
        else:
            # Running locally: the caller passed the already-unpickled data.
            qmla_core_info_dict = qmla_core_info_database.get("qmla_settings")
            self.probes_system = qmla_core_info_database["probes_system"]
            self.probes_simulator = qmla_core_info_database["probes_simulator"]

        # FIX: use a context manager so the probes file handle is closed
        # (was pickle.load(open(...)) which leaked the file object).
        with open(qmla_core_info_dict["probes_plot_file"], "rb") as probes_file:
            self.plot_probes = pickle.load(probes_file)
        self.plots_directory = qmla_core_info_dict["plots_directory"]
        self.debug_mode = qmla_core_info_dict["debug_mode"]
        self.plot_level = qmla_core_info_dict["plot_level"]
        self.figure_format = qmla_core_info_dict["figure_format"]

        # Assign attributes based on core data
        self.num_experiments = qmla_core_info_dict["num_experiments"]
        self.num_particles = qmla_core_info_dict["num_particles"]
        self.probe_number = qmla_core_info_dict["num_probes"]
        self.true_model_constituent_operators = qmla_core_info_dict["true_oplist"]
        self.true_model_params = qmla_core_info_dict["true_model_terms_params"]
        self.true_model_name = qmla_core_info_dict["true_name"]
        self.true_param_dict = qmla_core_info_dict["true_param_dict"]
        self.true_model_constructor = qmla_core_info_dict["true_model_constructor"]
        self.experimental_measurements = qmla_core_info_dict[
            "experimental_measurements"
        ]
        self.experimental_measurement_times = qmla_core_info_dict[
            "experimental_measurement_times"
        ]
        self.results_directory = qmla_core_info_dict["results_directory"]

        if learned_model_info is None:
            # Get data specific to this model, learned elsewhere and stored on
            # redis database
            try:
                redis_databases = qmla.redis_settings.get_redis_databases_by_qmla_id(
                    host_name, port_number, qid
                )
                learned_models_info_db = redis_databases["learned_models_info_db"]
            except BaseException:
                print("Unable to retrieve redis database.")
                raise

            # Models are keyed on redis by the string of their float ID,
            # e.g. model 3 -> "3.0".
            model_id_str = str(float(model_id))
            try:
                # latin1 encoding handles pickles produced under Python 2.
                learned_model_info = pickle.loads(
                    learned_models_info_db.get(model_id_str), encoding="latin1"
                )
            except BaseException:
                try:
                    learned_model_info = pickle.loads(
                        learned_models_info_db.get(model_id_str)
                    )
                except Exception:
                    # FIX: was a bare `except:`; narrowed so that
                    # KeyboardInterrupt/SystemExit are not swallowed.
                    # NOTE(review): on this path learned_model_info stays
                    # unset/None and the attribute assignments below will
                    # fail — consider re-raising here.
                    self.log_print(["Failed to unload model data for comparison"])

        # Assign parameters from model learned info, retrieved from database
        self.model_name = learned_model_info["name"]
        self.times_learned_over = learned_model_info["times_learned_over"]
        self.final_learned_params = learned_model_info["final_learned_params"]
        self.exploration_strategy_of_this_model = learned_model_info[
            "exploration_strategy_of_this_model"
        ]
        self.posterior_marginal = learned_model_info["posterior_marginal"]
        self.model_normalization_record = learned_model_info[
            "model_normalization_record"
        ]
        self.log_total_likelihood = learned_model_info["log_total_likelihood"]
        self.estimated_mean_params = learned_model_info["estimated_mean_params"]
        self.qhl_final_param_estimates = learned_model_info["qhl_final_param_estimates"]
        self.qhl_final_param_uncertainties = learned_model_info[
            "qhl_final_param_uncertainties"
        ]
        self.covariance_mtx_final = learned_model_info["covariance_mtx_final"]
        self.expectation_values = learned_model_info["expectation_values"]
        self.learned_hamiltonian = learned_model_info["learned_hamiltonian"]
        self.track_experiment_parameters = learned_model_info[
            "track_experiment_parameters"
        ]
        self.log_print(["Track exp params eg:", self.track_experiment_parameters[0]])

        # Process data from learned info
        if self.model_name == self.true_model_name:
            self.is_true_model = True
            self.log_print(["This is the true model for comparison."])
        else:
            self.is_true_model = False

        self.exploration_class = qmla.get_exploration_strategy.get_exploration_class(
            exploration_rules=self.exploration_strategy_of_this_model,
            log_file=self.log_file,
            qmla_id=self.qmla_id,
        )
        self.model_constructor = self.exploration_class.model_constructor(
            name=self.model_name
        )
        self.model_name_latex = self.model_constructor.name_latex
        self.model_terms_matrices = self.model_constructor.terms_matrices
        self.model_terms_parameters_final = np.array(self.final_learned_params)
        # self.model_name_latex = self.exploration_class.latex_name(self.model_name)

        # New instances of model and updater used by QInfer
        self.log_print(["Getting QInfer model"])
        self.qinfer_model = self.exploration_class.get_qinfer_model(
            model_name=self.model_name,
            model_constructor=self.model_constructor,
            true_model_constructor=self.true_model_constructor,
            num_probes=self.probe_number,
            probes_system=self.probes_system,
            probes_simulator=self.probes_simulator,
            exploration_rules=self.exploration_strategy_of_this_model,
            experimental_measurements=self.experimental_measurements,
            experimental_measurement_times=self.experimental_measurement_times,
            qmla_id=self.qmla_id,
            log_file=self.log_file,
            debug_mode=self.debug_mode,
        )

        # Reconstruct the updater from results of learning
        self.reconstruct_updater = True  # optionally just load it
        if self.reconstruct_updater:
            try:
                # TODO this can cause problems - some models have singular cov mt - WHY?
                posterior_distribution = qi.MultivariateNormalDistribution(
                    self.estimated_mean_params, self.covariance_mtx_final
                )
            except Exception:
                # FIX: was a bare `except:`; narrowed. Behavior preserved:
                # log the offending matrix, then re-raise.
                self.log_print(
                    [
                        "cov mtx is singular in trying to reconstruct SMC updater.\n",
                        self.covariance_mtx_final,
                    ]
                )
                raise
            # this allows the exploration strategy to use less particles for
            # the comparison stage (floor of 5 particles)
            num_particles_for_bf = max(
                5,
                int(
                    self.exploration_class.fraction_particles_for_bf
                    * self.num_particles
                ),
            )
            self.qinfer_updater = qi.SMCUpdater(
                model=self.qinfer_model,
                n_particles=num_particles_for_bf,
                prior=posterior_distribution,
                resample_thresh=self.exploration_class.qinfer_resampler_threshold,
                resampler=qi.LiuWestResampler(
                    a=self.exploration_class.qinfer_resampler_a
                ),
            )
            # Seed the updater with the normalization record from learning so
            # log-likelihoods are comparable.
            self.qinfer_updater._normalization_record = self.model_normalization_record
        else:
            # Optionally pickle the entire updater
            # (first include updater in ModelInstanceForLearning.learned_info_dict())
            self.qinfer_updater = pickle.loads(learned_model_info["updater"])

        # Fresh experiment design heuristic
        self.experiment_design_heuristic = self.exploration_class.get_heuristic(
            model_id=self.model_id,
            updater=self.qinfer_updater,
            oplist=self.model_terms_matrices,
            num_experiments=self.num_experiments,
            num_probes=self.probe_number,
            log_file=self.log_file,
            inv_field=[item[0] for item in self.qinfer_model.expparams_dtype[1:]],
            max_time_to_enforce=self.exploration_class.max_time_to_consider,
            figure_format=self.figure_format,
        )

        # Delete extra data now that everything useful is extracted
        del qmla_core_info_dict, learned_model_info
        self.log_print(["Instantiated."])

    ##########
    # Section: update for Bayes factor
    ##########

    def update_log_likelihood(
        self,
        new_times,
        new_experimental_params,
    ):
        r"""
        Update this model's posterior with the opponent's experiments.

        Truncates this model's normalization record to the final fraction
        ``fraction_own_experiments_for_bf`` of its own experiments, then
        replays the final fraction ``fraction_opponents_experiments_for_bf``
        of the opponent's experiments (their probe IDs and times) through
        this model's QInfer updater.

        :param new_times: times the opponent trained over (used for
            truncation bookkeeping).
        :param new_experimental_params: opponent's experiment parameter
            records; each must expose ``["probe_id"]`` and ``["t"]``.
        :return: the updater's log total likelihood after the updates.
        """
        # Reduced normalization record using only experiments to consider
        experiment_id_to_keep = int(
            len(self.qinfer_updater.normalization_record)
            - (
                self.exploration_class.fraction_own_experiments_for_bf
                * len(self.qinfer_updater.normalization_record)
            )
        )
        self.qinfer_updater._normalization_record = (
            self.qinfer_updater._normalization_record[experiment_id_to_keep:]
        )
        self.bf_times = self.times_learned_over[experiment_id_to_keep:]

        # List of opponent's times, possibly shortened
        experiment_id_to_keep = int(
            len(new_times)
            - (
                self.exploration_class.fraction_opponents_experiments_for_bf
                * len(new_times)
            )
        )
        epoch_id = len(self.times_learned_over)
        experiments_to_update_with = new_experimental_params[experiment_id_to_keep:]
        self.log_print(["Times to update length:", len(experiments_to_update_with)])

        for experiment in experiments_to_update_with:
            # sample from own updater/heuristic so particle is correct shape
            experiment_for_update = self.experiment_design_heuristic(epoch_id=epoch_id)
            # retrieve probe and time used by opponent
            experiment_for_update["probe_id"] = experiment["probe_id"][0]
            exp_time = experiment["t"][0]
            experiment_for_update["t"] = exp_time
            self.bf_times.append(experiment["t"])

            # run experiment
            params_array = np.array([[self.true_model_params[:]]])
            self.log_print_debug(["BF update epoch ", epoch_id])
            datum = self.qinfer_model.simulate_experiment(
                params_array, experiment_for_update, repeat=1
            )

            # update qinfer
            self.qinfer_updater.update(datum, experiment_for_update)
            epoch_id += 1

        self.log_print(["BF times:", self.bf_times])
        # NOTE(review): qmla.utilities is not imported at the top of this
        # module — presumably made available via the qmla package; confirm.
        self.bf_times = qmla.utilities.flatten(self.bf_times)
        return self.qinfer_updater.log_total_likelihood

    ##########
    # Section: Plotting
    ##########

    def plot_dynamics(self, ax, times):
        r"""
        Plot dynamics of this model after its parameter learning stage.

        :param ax: matplotlib axis to plot on
        :param list times: times against which to plot
        """
        # Only compute expectation values not already cached from learning.
        times_not_yet_computed = list(set(times) - set(self.expectation_values.keys()))
        n_qubits = self.model_constructor.num_qubits
        plot_probe = self.plot_probes[n_qubits]
        for t in times_not_yet_computed:
            self.expectation_values[t] = self.exploration_class.get_expectation_value(
                ham=self.learned_hamiltonian,
                t=t,
                state=plot_probe,  # TODO
            )
        l = ax.plot(
            times,
            [self.expectation_values[t] for t in times],
            label="{}".format(self.model_id),
        )
        return l

    ##########
    # Section: Utilities
    ##########

    def log_print(
        self,
        to_print_list,
        log_identifier=None,
    ):
        r"""Wrapper for :func:`~qmla.print_to_log`"""
        if log_identifier is None:
            log_identifier = "ModelForComparison {} (vs {})".format(
                int(self.model_id), self.opponent
            )
        qmla.logging.print_to_log(
            to_print_list=to_print_list,
            log_file=self.log_file,
            log_identifier=log_identifier,
        )

    def log_print_debug(self, to_print_list):
        r"""Log print if global debug_log_print set to True."""
        if self.debug_mode:
            self.log_print(
                to_print_list=to_print_list,
                log_identifier="Debug Comparison Model {}".format(self.model_id),
            )