# Source code for qmla.remote_model_learning

from __future__ import print_function  # so print doesn't show brackets

import copy
import numpy as np
import time as time

import matplotlib.pyplot as plt
import pickle
import redis

import qmla.model_for_learning
import qmla.redis_settings
import qmla.logging

# Force pickle's advertised highest protocol to 4. NOTE: this assigns to an
# attribute of the stdlib ``pickle`` module, so it is process-wide and affects
# every module that reads ``pickle.HIGHEST_PROTOCOL`` after this import.
# Presumably done so pickles stay readable by interpreters without protocol 5
# -- confirm with other qmla modules before removing.
pickle.HIGHEST_PROTOCOL = 4
# Use matplotlib's non-interactive Agg backend, presumably so figures can be
# saved to file on worker machines that have no display.
plt.switch_backend("agg")

# Public API of this module: only the remote learning entry point.
__all__ = ["remote_learn_model_parameters"]


def remote_learn_model_parameters(
    name,
    model_id,
    branch_id,
    exploration_rule,
    qmla_core_info_dict=None,
    remote=False,
    host_name="localhost",
    port_number=6379,
    qid=0,
    log_file="rq_output.log",
):
    """
    Standalone function to perform Quantum Hamiltonian Learning on individual models.

    Used in conjunction with redis databases so this calculation can be
    performed without any knowledge of the QMLA instance.
    Given model ids and names are used to instantiate
    the ModelInstanceForLearning class, which is then used
    for learning the model's parameters.
    QMLA info is unpickled from a redis database, containing
    true operator, params etc.
    Once parameters are learned, we pickle the results to dictionaries
    held on a redis database which can be accessed by other actors.

    Redis writes of the results are retried a fixed number of times; if they
    still fail, the ``"Status"`` key of the ``any_job_failed`` database is set
    to 1 rather than raising.

    :param str name: model name string
    :param int model_id: unique model id
    :param int branch_id: QMLA branch where the model was generated
    :param str exploration_rule: string corresponding to a unique exploration strategy,
        used by get_exploration_class to generate an
        ExplorationStrategy (or subclass) instance.
    :param dict qmla_core_info_dict: crucial data for QMLA, such as number
        of experiments/particles etc. Default None: core info is stored on the
        redis database so can be retrieved there on a server; if running locally,
        can be passed to save pickling. Must provide the keys ``"true_oplist"``,
        ``"qhl_plots"``, ``"plots_directory"`` and ``"long_id"``.
    :param bool remote: whether QMLA is running remotely via RQ workers.
    :param str host_name: name of host server on which redis database exists.
    :param int port_number: this QMLA instance's unique port number,
        on which redis database exists.
    :param int qid: QMLA id, unique to a single instance within a run.
        Used to identify the redis database corresponding to this instance.
    :param str log_file: Path of the log file.
    :return: ``None`` when ``remote`` is True; otherwise the deep-copied
        dictionary returned by ``learned_info_dict()`` of the model instance.
    :raises BaseException: any exception raised while constructing the model
        instance or learning its parameters is re-raised after the job-failure
        flag (``"Status"`` on the ``any_job_failed`` database) is set to 1.
    """

    def log_print(to_print_list):
        qmla.logging.print_to_log(
            to_print_list=to_print_list,
            log_file=log_file,
            log_identifier="RemoteLearnModel {}".format(model_id),
        )

    log_print(["Starting QHL for Model {} on branch {}".format(model_id, branch_id)])
    time_start = time.time()
    num_redis_retries = 5

    # Access databases
    redis_databases = qmla.redis_settings.get_redis_databases_by_qmla_id(
        host_name, port_number, qid
    )
    qmla_core_info_database = redis_databases["qmla_core_info_database"]
    learned_models_info_db = redis_databases["learned_models_info_db"]
    learned_models_ids = redis_databases["learned_models_ids"]
    active_branches_learning_models = redis_databases["active_branches_learning_models"]
    any_job_failed_db = redis_databases["any_job_failed"]

    if qmla_core_info_dict is None:
        # NOTE: unpickling is only safe because this redis server is assumed
        # to be written solely by the QMLA instance itself.
        qmla_core_info_dict = pickle.loads(qmla_core_info_database["qmla_settings"])

    true_model_terms_matrices = qmla_core_info_dict["true_oplist"]
    qhl_plots = qmla_core_info_dict["qhl_plots"]
    plots_directory = qmla_core_info_dict["plots_directory"]
    long_id = qmla_core_info_dict["long_id"]

    try:
        # Construct inside the try so that a failure here also raises the
        # job-failure flag instead of leaving the master waiting on this model.
        qml_instance = qmla.model_for_learning.ModelInstanceForLearning(
            model_id=model_id,
            model_name=name,
            qid=qid,
            log_file=log_file,
            exploration_rule=exploration_rule,
            host_name=host_name,
            port_number=port_number,
        )
        # Learn parameters
        update_timer_start = time.time()
        qml_instance.update_model()
        log_print(
            ["Time for update alone: {}".format(time.time() - update_timer_start)]
        )
    except BaseException as e:
        # BaseException on purpose: even an interrupted job must be flagged
        # as failed; the exception is always re-raised.
        log_print(
            [
                "Model learning failed. QHL failed for model id {}. "
                "Setting job failure status. Error: {!r}".format(model_id, e)
            ]
        )
        any_job_failed_db.set("Status", 1)
        raise

    if qhl_plots:
        _draw_qhl_plots(
            qml_instance, true_model_terms_matrices, plots_directory, long_id, log_print
        )

    # Throw away model instance; only need to store results.
    updated_model_info = copy.deepcopy(qml_instance.learned_info_dict())
    compressed_info = pickle.dumps(updated_model_info, protocol=4)

    # Store the (compressed) result set on the redis database.
    _store_learned_info(
        learned_models_info_db,
        any_job_failed_db,
        model_id,
        compressed_info,
        num_redis_retries,
        log_print,
    )

    # Update databases to record that this model has finished.
    _record_model_learned(
        active_branches_learning_models,
        learned_models_ids,
        any_job_failed_db,
        model_id,
        branch_id,
        num_redis_retries,
        log_print,
    )

    if remote:
        # Drop large references before the final log call.
        del updated_model_info
        del compressed_info
        del qml_instance
        log_print(
            [
                "Learned model; remote time:",
                str(np.round((time.time() - time_start), 2)),
            ]
        )
        return None
    return updated_model_info


def _draw_qhl_plots(
    qml_instance, true_model_terms_matrices, plots_directory, long_id, log_print
):
    """Best-effort plotting of the parameter distribution progression.

    Only draws when the true model has a single term. Plotting errors are
    logged and otherwise ignored, so they never fail the learning job;
    KeyboardInterrupt/SystemExit still propagate.
    """
    log_print(["Drawing plots for QHL"])
    try:
        if len(true_model_terms_matrices) == 1:  # TODO buggy
            qml_instance.plot_distribution_progression(
                save_to_file="{}qhl_distribution_progression_{}.png".format(
                    plots_directory, long_id
                )
            )
            qml_instance.plot_distribution_progression(
                renormalise=False,
                save_to_file="{}qhl_distribution_progression_uniform_{}.png".format(
                    plots_directory, long_id
                ),
            )
    except Exception as e:
        log_print(["QHL plotting failed (ignored). Error: {!r}".format(e)])


def _store_learned_info(
    learned_models_info_db,
    any_job_failed_db,
    model_id,
    compressed_info,
    num_retries,
    log_print,
):
    """Write the pickled learned-model info to redis, retrying on failure.

    :return: True if the write succeeded; False if every attempt failed, in
        which case the failure is logged and the job-failure flag is set.
    """
    for attempt in range(1, num_retries + 1):
        try:
            learned_models_info_db.set(str(model_id), compressed_info)
        except Exception as e:
            if attempt == num_retries:
                log_print(
                    ["Model learning failed at the storage stage. Error: {}".format(e)]
                )
                any_job_failed_db.set("Status", 1)
        else:
            log_print(
                [
                    "learned_models_info_db added to db for model {} after {} attempts".format(
                        str(model_id), attempt
                    )
                ]
            )
            return True
    return False


def _record_model_learned(
    active_branches_learning_models,
    learned_models_ids,
    any_job_failed_db,
    model_id,
    branch_id,
    num_retries,
    log_print,
):
    """Increment the branch's finished-model counter and mark the model learned.

    The (non-idempotent) counter increment is tracked separately from the
    (idempotent) model flag, so that retrying after a partial failure does
    not increment the branch counter a second time, as far as the client
    can tell.

    :return: True on success; False if every attempt failed, in which case
        the failure is logged and the job-failure flag is set.
    """
    branch_incremented = False
    for attempt in range(1, num_retries + 1):
        try:
            if not branch_incremented:
                active_branches_learning_models.incr(int(branch_id), 1)
                branch_incremented = True
            learned_models_ids.set(str(model_id), 1)
        except Exception as e:
            if attempt == num_retries:
                log_print(["Model learning failed to update branch info. Error: ", e])
                any_job_failed_db.set("Status", 1)
        else:
            log_print(
                [
                    "Updated model/branch learned on redis db {}/{}".format(
                        model_id, branch_id
                    )
                ]
            )
            return True
    return False