# Source code for wofry.propagator.propagator

"""
Core wofry propagator infrastructure: Propagator, PropagationManager, and PropagationParameters.
"""
from srxraylib.util.threading import Singleton, synchronized_method

from syned.beamline.beamline_element import BeamlineElement

from wofry.propagator.wavefront  import Wavefront, WavefrontDimension

class PropagationElements(object):
    """Ordered collection of beamline elements with optional per-element parameters.

    Two parallel lists are maintained: the beamline elements themselves and,
    at the same index, the parameters specific to each element (``None`` when
    an element has none). Every mutating method updates both lists together
    so that they always have the same length.
    """

    # Insertion modes accepted by insert_beamline_element().
    INSERT_AFTER = 0
    INSERT_BEFORE = 1

    def __init__(self):
        self.__propagation_elements = []
        self.__propagation_element_parameters = []

    def add_beamline_element(self, beamline_element=BeamlineElement(), element_parameters=None):
        """Append one beamline element and its (optional) specific parameters.

        Raises:
            ValueError: if ``beamline_element`` is None.
        """
        if beamline_element is None:
            raise ValueError("Beamline is None")

        self.__propagation_elements.append(beamline_element)
        self.__propagation_element_parameters.append(element_parameters)

    def add_beamline_elements(self, beamline_elements=[], element_parameters_list=None):
        """Append several beamline elements, in order.

        ``element_parameters_list`` must have the same length as
        ``beamline_elements`` when given; when omitted, every element is
        stored with ``None`` as its parameters. The default list argument is
        only read, never mutated.

        Raises:
            ValueError: if ``beamline_elements`` is None or the two lists
                differ in length.
        """
        if beamline_elements is None:
            raise ValueError("Beamline is None")

        if element_parameters_list is not None:
            if len(beamline_elements) != len(element_parameters_list):
                raise ValueError("Specific Parameters list does not match Beamline Elements list")
        else:
            element_parameters_list = [None] * len(beamline_elements)

        for beamline_element, element_parameters in zip(beamline_elements, element_parameters_list):
            self.add_beamline_element(beamline_element, element_parameters)

    def insert_beamline_element(self, index, new_element=BeamlineElement(), mode=INSERT_BEFORE, new_element_parameters=None):
        """Insert an element (and its parameters) before or after ``index``.

        ``mode`` is ``INSERT_BEFORE`` or ``INSERT_AFTER``; any other value
        leaves the collection untouched. Both lists are modified in place.
        """
        if mode == PropagationElements.INSERT_BEFORE:
            position = index
        elif mode == PropagationElements.INSERT_AFTER:
            position = index + 1
        else:
            return

        # list.insert() already covers the "prepend" (position 0) and "append"
        # (position == len) cases. The element and its parameters must be
        # inserted at the same position: replacing the parameter list with a
        # one-item list when prepending would drop the parameters of every
        # existing element and desynchronise the two lists.
        self.__propagation_elements.insert(position, new_element)
        self.__propagation_element_parameters.insert(position, new_element_parameters)

    def get_propagation_elements_number(self):
        """Return the number of stored beamline elements."""
        return len(self.__propagation_elements)

    def get_propagation_elements(self):
        """Return the internal list of beamline elements (not a copy)."""
        return self.__propagation_elements

    def get_propagation_element(self, index):
        """Return the beamline element stored at ``index``."""
        return self.__propagation_elements[index]

    def get_propagation_elements_parameters(self):
        """Return the internal list of per-element parameters (not a copy)."""
        return self.__propagation_element_parameters

    def get_propagation_element_parameter(self, index):
        """Return the parameters stored for the element at ``index`` (may be None)."""
        return self.__propagation_element_parameters[index]
class PropagationParameters(object):
    """Bundle of inputs handed to a propagator.

    Holds the wavefront to propagate, the ``PropagationElements`` to traverse
    and a dictionary of free-form additional parameters collected from the
    constructor's keyword arguments.
    """

    def __init__(self, wavefront = Wavefront(), propagation_elements = PropagationElements(), **additional_parameters):
        self._additional_parameters = additional_parameters
        self._propagation_elements = propagation_elements
        self._wavefront = wavefront

    def get_wavefront(self):
        """Return the wavefront given at construction."""
        return self._wavefront

    def get_PropagationElements(self):
        """Return the ``PropagationElements`` given at construction."""
        return self._propagation_elements

    def set_additional_parameters(self, key, value):
        """Store ``value`` under ``key``, creating the dictionary if it is None."""
        if self._additional_parameters is None:
            self._additional_parameters = {key : value}
            return

        self._additional_parameters[key] = value

    def get_additional_parameter(self, key):
        """Return the additional parameter stored under ``key`` (plain dict lookup)."""
        stored = self._additional_parameters
        return stored[key]

    def has_additional_parameter(self, key):
        """Tell whether an additional parameter is stored under ``key``."""
        stored = self._additional_parameters
        return key in stored
class AbstractPropagator(object):
    """Base interface of a propagator.

    Subclasses must provide ``get_dimension``, ``get_handler_name`` and
    ``do_propagation``; the versions defined here only raise
    ``NotImplementedError``.
    """

    def __init__(self):
        super().__init__()

    def get_dimension(self):
        """Abstract: return the wavefront dimension handled by this propagator."""
        raise NotImplementedError("This method is abstract")

    def get_handler_name(self):
        """Abstract: return the name identifying this propagator."""
        raise NotImplementedError("This method is abstract")

    def is_handler(self, handler_name):
        """Tell whether ``handler_name`` equals this propagator's handler name."""
        own_name = self.get_handler_name()
        return handler_name == own_name

    def do_propagation(self, parameters=PropagationParameters()):
        """Abstract: propagate according to ``parameters``.

        The error message spells out the fully qualified input and output
        types an implementation is expected to use.
        """
        accepted_type = PropagationParameters.__module__ + "." + PropagationParameters.__name__
        returned_type = Wavefront.__module__ + "." + Wavefront.__name__

        raise NotImplementedError("This method is abstract\n\naccepts " + accepted_type + "\nreturns " + returned_type)
class PropagationMode:
    """Integer constants naming the available propagation modes."""

    # Plain ints (not an Enum) so that comparisons with stored values keep working.
    STEP_BY_STEP, WHOLE_BEAMLINE = 0, 1
class PropagationApplication:
    """String constants naming the application a setting applies to."""

    # Key under which the settings shared by every application are stored.
    ALL = str("All")
@Singleton
class PropagationManager(object):
    """Registry of propagators and of per-application propagation settings.

    Propagators are kept in one chain (a list) per wavefront dimension and
    are looked up by handler name. Propagation mode and initialization flags
    are stored per application, keyed by name.

    NOTE(review): the class is wrapped by ``Singleton`` and most methods by
    ``synchronized_method`` (both from ``srxraylib.util.threading``);
    presumably this gives one shared instance with serialized access --
    confirm against srxraylib.
    """

    def __init__(self):
        self.__chains_hashmap = {WavefrontDimension.ONE : [],
                                 WavefrontDimension.TWO : []}
        self.__propagation_mode_hashmap = {PropagationApplication.ALL : PropagationMode.STEP_BY_STEP}
        self.__is_initialized_hashmap = {PropagationApplication.ALL : False}

    @synchronized_method
    def set_initialized(self, application = PropagationApplication.ALL, initialized=True):
        """Record the initialization flag of ``application``."""
        self.__is_initialized_hashmap[application] = initialized

    @synchronized_method
    def is_initialized(self, application = PropagationApplication.ALL):
        """Return the initialization flag of ``application`` (False if unknown)."""
        return self.__is_initialized_hashmap.get(application, False)

    @synchronized_method
    def set_propagation_mode(self, application = PropagationApplication.ALL, mode=PropagationMode.STEP_BY_STEP):
        """Record the propagation mode of ``application``."""
        self.__propagation_mode_hashmap[application] = mode

    @synchronized_method
    def get_propagation_mode(self, application = PropagationApplication.ALL):
        """Return the propagation mode of ``application``.

        Raises:
            KeyError: if no mode was ever set for ``application``.
        """
        return self.__propagation_mode_hashmap[application]

    @synchronized_method
    def add_propagator(self, propagator=AbstractPropagator()):
        """Register ``propagator`` in the chain matching its dimension.

        Raises:
            ValueError: if the propagator is None, is not an
                ``AbstractPropagator``, reports an unsupported dimension, or
                a propagator with the same handler name is already
                registered for that dimension.
        """
        if propagator is None:
            raise ValueError("Given propagator is None")
        if not isinstance(propagator, AbstractPropagator):
            raise ValueError("Given propagator is not a compatible object")

        dimension = propagator.get_dimension()

        if dimension not in (WavefrontDimension.ONE, WavefrontDimension.TWO):
            raise ValueError("Wrong propagator dimension")

        propagation_chain_of_responsibility = self.__chains_hashmap.get(dimension)

        for existing in propagation_chain_of_responsibility:
            if existing.is_handler(propagator.get_handler_name()):
                raise ValueError(f"Propagator {propagator.get_handler_name()} already in the Chain")

        propagation_chain_of_responsibility.append(propagator)

    @synchronized_method
    def has_propagator(self, handler_name="<Propagator Name>", dimension=WavefrontDimension.ONE):
        """Tell whether a propagator named ``handler_name`` is registered for ``dimension``."""
        propagation_chain_of_responsibility = self.__chains_hashmap.get(dimension)

        # Same matching rule (is_handler) as add_propagator() and do_propagation().
        return any(existing.is_handler(handler_name)
                   for existing in propagation_chain_of_responsibility)

    @synchronized_method
    def get_propagators_number(self, dimension=None):
        """Return how many propagators are registered.

        With ``dimension=None`` a ``(count_1D, count_2D)`` tuple is returned;
        otherwise the count for the requested dimension.

        Raises:
            ValueError: if ``dimension`` is neither None nor a supported
                wavefront dimension.
        """
        propagation_chain_of_responsibility_1D = self.__chains_hashmap.get(WavefrontDimension.ONE)
        propagation_chain_of_responsibility_2D = self.__chains_hashmap.get(WavefrontDimension.TWO)

        if dimension is None:
            return len(propagation_chain_of_responsibility_1D), len(propagation_chain_of_responsibility_2D)
        elif dimension == WavefrontDimension.ONE:
            return len(propagation_chain_of_responsibility_1D)
        elif dimension == WavefrontDimension.TWO:
            return len(propagation_chain_of_responsibility_2D)
        else:
            raise ValueError("Dimension not valid " + str(dimension))

    # Not wrapped in synchronized_method in the original code; kept that way.
    def do_propagation(self, propagation_parameters, handler_name):
        """Run the propagator named ``handler_name`` on ``propagation_parameters``.

        The chain is selected from the dimension of the wavefront carried by
        ``propagation_parameters``; the first propagator accepting
        ``handler_name`` performs the propagation and its result is returned.

        Raises:
            Exception: if no registered propagator handles ``handler_name``.
        """
        dimension = propagation_parameters.get_wavefront().get_dimension()

        for propagator in self.__chains_hashmap.get(dimension):
            if propagator.is_handler(handler_name):
                return propagator.do_propagation(parameters=propagation_parameters)

        # f-string: building the message must not itself fail with TypeError
        # when handler_name is not a string.
        raise Exception(f"Handler not found: {handler_name}")

# ---------------------------------------------------------------
class Propagator(AbstractPropagator):
    """Propagator that walks the beamline elements one after the other.

    Subclasses supply the actual free-space propagation through
    ``do_specific_progation_before`` and ``do_specific_progation_after``.
    """

    def do_propagation(self, parameters=PropagationParameters()):
        """Propagate the wavefront through every element in ``parameters``.

        For each element, in order: propagate over the distance
        ``coordinates.p()`` when it is non-zero, apply the optical element,
        then propagate over ``coordinates.q()`` when it is non-zero. The
        wavefront obtained after the last element is returned.
        """
        wavefront = parameters.get_wavefront()
        propagation_elements = parameters.get_PropagationElements()

        for index in range(propagation_elements.get_propagation_elements_number()):
            element = propagation_elements.get_propagation_element(index)
            coordinates = element.get_coordinates()

            if coordinates.p() != 0.0:
                wavefront = self.do_specific_progation_before(wavefront, coordinates.p(), parameters, element_index=index)

            wavefront = element.get_optical_element().applyOpticalElement(wavefront, parameters, element_index=index)

            if coordinates.q() != 0.0:
                wavefront = self.do_specific_progation_after(wavefront, coordinates.q(), parameters, element_index=index)

        return wavefront

    def do_specific_progation_before(self, wavefront, propagation_distance, parameters, element_index=None):
        """Abstract: propagate ``wavefront`` over the distance preceding an element."""
        raise NotImplementedError("This method is abstract")

    def do_specific_progation_after(self, wavefront, propagation_distance, parameters, element_index=None):
        """Abstract: propagate ``wavefront`` over the distance following an element."""
        raise NotImplementedError("This method is abstract")

    def get_additional_parameter(self, parameter_name, default_value, propagation_parameters, element_index=None, ):
        """Resolve a named parameter, most specific source first.

        Precedence, highest first: the parameters attached to the element at
        ``element_index`` (index 0 when ``element_index`` is None), then the
        additional parameters of ``propagation_parameters``, then
        ``default_value``.
        """
        value = default_value

        # Best-effort lookup: a missing global parameter simply keeps the
        # default. ``Exception`` (not a bare ``except``) so that
        # KeyboardInterrupt and SystemExit are no longer swallowed.
        try:
            value = propagation_parameters.get_additional_parameter(parameter_name)
        except Exception:
            pass

        myindex = 0 if element_index is None else element_index

        parameters_dict = propagation_parameters.get_PropagationElements().get_propagation_element_parameter(myindex)

        # Element-specific parameters override the global value.
        if parameters_dict is not None and parameter_name in parameters_dict:
            value = parameters_dict[parameter_name]

        return value
class Propagator1D(Propagator):
    """Propagator restricted to one-dimensional wavefronts."""

    def get_dimension(self):
        """Return ``WavefrontDimension.ONE``."""
        return WavefrontDimension.ONE

    def do_propagation(self, parameters=PropagationParameters()):
        """Run the inherited propagation after checking the wavefront is 1D.

        Raises:
            Exception: if the wavefront dimension is not ``WavefrontDimension.ONE``.
        """
        wavefront_dimension = parameters.get_wavefront().get_dimension()

        if wavefront_dimension == WavefrontDimension.ONE:
            return super().do_propagation(parameters)

        raise Exception("wrong wavefront! it is not 1D")
class Propagator2D(Propagator):
    """Propagator restricted to two-dimensional wavefronts."""

    def get_dimension(self):
        """Return ``WavefrontDimension.TWO``."""
        return WavefrontDimension.TWO

    def do_propagation(self, parameters=PropagationParameters()):
        """Run the inherited propagation after checking the wavefront is 2D.

        Raises:
            Exception: if the wavefront dimension is not ``WavefrontDimension.TWO``.
        """
        wavefront_dimension = parameters.get_wavefront().get_dimension()

        if wavefront_dimension == WavefrontDimension.TWO:
            return super().do_propagation(parameters)

        raise Exception("wrong wavefront! it is not 2D")