diff --git a/src/pyFAI/geometry/core.py b/src/pyFAI/geometry/core.py index 6232d84fd..fe25b136b 100644 --- a/src/pyFAI/geometry/core.py +++ b/src/pyFAI/geometry/core.py @@ -124,7 +124,10 @@ class Geometry(object): "chiDiscAtPi", "_wavelength", "_dssa_order", '_oversampling', '_correct_solid_angle_for_spline', '_transmission_normal') - + PROMOTION = {"AzimuthalIntegrator": "pyFAI.integrator.azimuthal.AzimuthalIntegrator", + "FiberIntegrator": "pyFAI.integrator.fiber.FiberIntegrator", + "GeometryRefinement": "pyFAI.geometryRefinement.GeometryRefinement", + "Geometry": "pyFAI.geometry.core.Geometry"} def __init__(self, dist=1, poni1=0, poni2=0, rot1=0, rot2=0, rot3=0, pixel1=None, pixel2=None, splineFile=None, detector=None, wavelength=None, @@ -2088,7 +2091,7 @@ def calcfrom2d(self, I, tth, chi, shape=None, mask=None, calcimage[numpy.where(mask)] = dummy return calcimage - def promote(self, type_="pyFAI.azimuthalIntegrator.AzimuthalIntegrator", kwargs=None): + def promote(self, type_="pyFAI.integrator.azimuthal.AzimuthalIntegrator", kwargs=None): """Promote this instance into one of its derived class (deep copy like) :param type_: Fully qualified name of the class to promote to, or the class itself @@ -2099,6 +2102,9 @@ def promote(self, type_="pyFAI.azimuthalIntegrator.AzimuthalIntegrator", kwargs= """ GeometryClass = self.__class__.__mro__[-2] # actually pyFAI.geometry.core.Geometry if isinstance(type_, str): + if "." not in type_: + if type_ in self.PROMOTION: + type_ = self.PROMOTION[type_] import importlib modules = type_.split(".") module_name = ".".join(modules[:-1]) @@ -2107,7 +2113,7 @@ def promote(self, type_="pyFAI.azimuthalIntegrator.AzimuthalIntegrator", kwargs= elif isinstance(type_, type): klass = type_ else: - raise ValueError("`type_` must be a class (or a fully qualified class name) of a Geometry derived class") + raise ValueError("`type_` must be a class (or class name) of a Geometry derived class") if kwargs == None: kwargs = {} diff --git a/src/pyFAI/test/test_geometry.py b/src/pyFAI/test/test_geometry.py index b08d5d603..08482b8ad 100644 --- a/src/pyFAI/test/test_geometry.py +++ b/src/pyFAI/test/test_geometry.py @@ -522,7 +522,9 @@ def test_promotion(self): idmask = id(g.detector.mask) ai = g.promote() self.assertEqual(type(ai).__name__, "AzimuthalIntegrator", "Promote to AzimuthalIntegrator by default") - gr = g.promote("pyFAI.geometryRefinement.GeometryRefinement") + ai = g.promote("FiberIntegrator") + self.assertEqual(type(ai).__name__, "FiberIntegrator", "Promote to FiberIntegrator when requested") + gr = g.promote("GeometryRefinement") self.assertEqual(type(gr).__name__, "GeometryRefinement", "Promote to GeometryRefinement when requested") gr = ai.promote("pyFAI.geometryRefinement.GeometryRefinement") self.assertEqual(type(gr).__name__, "GeometryRefinement", "Promote to GeometryRefinement when requested")