From ee216a72303253b24b9f731a400decef187e945f Mon Sep 17 00:00:00 2001 From: Jurgen Lentz Date: Sun, 27 Oct 2024 21:41:00 +0100 Subject: [PATCH] fix more leaks --- amplpy/ampl.pyx | 90 ++++++++++++++++++++++++++++------- amplpy/constraint.pxi | 5 +- amplpy/dataframe.pxi | 13 +++-- amplpy/entity.pxi | 47 ++++++++++++------- amplpy/iterators.pxi | 107 ++++++++++++++++++++++++++---------------- amplpy/objective.pxi | 5 +- amplpy/parameter.pxi | 9 +++- amplpy/set.pxi | 8 +++- amplpy/util.pxi | 14 +++--- amplpy/variable.pxi | 5 +- 10 files changed, 213 insertions(+), 90 deletions(-) diff --git a/amplpy/ampl.pyx b/amplpy/ampl.pyx index 405da34..8da1595 100644 --- a/amplpy/ampl.pyx +++ b/amplpy/ampl.pyx @@ -8,6 +8,7 @@ from libc.string cimport strdup from cpython.bool cimport PyBool_Check +from cpython cimport Py_INCREF, Py_DECREF from numbers import Real from ast import literal_eval @@ -179,14 +180,19 @@ cdef class AMPL: DataFrame capturing the output of the display command in tabular form. """ + cdef campl.AMPL_ERRORINFO* errorinfo + cdef campl.AMPL_RETCODE rc cdef campl.AMPL_DATAFRAME* data cdef char** statements_c = malloc(len(statements) * sizeof(char*)) for i in range(len(statements)): statements_c[i] = strdup(statements[i].encode('utf-8')) - PY_AMPL_CALL(campl.AMPL_GetData(self._c_ampl, statements_c, len(statements), &data)) + errorinfo = campl.AMPL_GetData(self._c_ampl, statements_c, len(statements), &data) + rc = campl.AMPL_ErrorInfoGetError(errorinfo) for i in range(len(statements)): free(statements_c[i]) free(statements_c) + if rc != campl.AMPL_OK: + PY_AMPL_CALL(errorinfo) return DataFrame.create(data) def get_entity(self, name): @@ -204,7 +210,7 @@ cdef class AMPL: The AMPL entity with the specified name. """ cdef char* name_c = strdup(name.encode('utf-8')) - return Entity.create(self._c_ampl, name_c, NULL) + return Entity.create(self._c_ampl, name_c, NULL, None) def get_variable(self, name): """ @@ -216,11 +222,19 @@ cdef class AMPL: Raises: KeyError: if the specified variable does not exist. """ + cdef campl.AMPL_ERRORINFO* errorinfo + cdef campl.AMPL_RETCODE rc cdef campl.AMPL_ENTITYTYPE entitytype cdef char* name_c = strdup(name.encode('utf-8')) - PY_AMPL_CALL(campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype)) - if entitytype != campl.AMPL_VARIABLE: raiseKeyError(campl.AMPL_VARIABLE, name) - return Variable.create(self._c_ampl, name_c, NULL) + errorinfo = campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype) + rc = campl.AMPL_ErrorInfoGetError(errorinfo) + if rc != campl.AMPL_OK: + free(name_c) + PY_AMPL_CALL(errorinfo) + if entitytype != campl.AMPL_VARIABLE: + free(name_c) + raiseKeyError(campl.AMPL_VARIABLE, name) + return Variable.create(self._c_ampl, name_c, NULL, None) def get_constraint(self, name): """ @@ -232,11 +246,19 @@ cdef class AMPL: Raises: KeyError: if the specified constraint does not exist. """ + cdef campl.AMPL_ERRORINFO* errorinfo + cdef campl.AMPL_RETCODE rc cdef campl.AMPL_ENTITYTYPE entitytype cdef char* name_c = strdup(name.encode('utf-8')) - PY_AMPL_CALL(campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype)) - if entitytype != campl.AMPL_CONSTRAINT: raiseKeyError(campl.AMPL_CONSTRAINT, name) - return Constraint.create(self._c_ampl, name_c, NULL) + errorinfo = campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype) + rc = campl.AMPL_ErrorInfoGetError(errorinfo) + if rc != campl.AMPL_OK: + free(name_c) + PY_AMPL_CALL(errorinfo) + if entitytype != campl.AMPL_CONSTRAINT: + free(name_c) + raiseKeyError(campl.AMPL_CONSTRAINT, name) + return Constraint.create(self._c_ampl, name_c, NULL, None) def get_objective(self, name): """ @@ -248,11 +270,19 @@ cdef class AMPL: Raises: KeyError: if the specified objective does not exist. """ + cdef campl.AMPL_ERRORINFO* errorinfo + cdef campl.AMPL_RETCODE rc cdef campl.AMPL_ENTITYTYPE entitytype cdef char* name_c = strdup(name.encode('utf-8')) - PY_AMPL_CALL(campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype)) - if entitytype != campl.AMPL_OBJECTIVE: raiseKeyError(campl.AMPL_OBJECTIVE, name) - return Objective.create(self._c_ampl, name_c, NULL) + errorinfo = campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype) + rc = campl.AMPL_ErrorInfoGetError(errorinfo) + if rc != campl.AMPL_OK: + free(name_c) + PY_AMPL_CALL(errorinfo) + if entitytype != campl.AMPL_OBJECTIVE: + free(name_c) + raiseKeyError(campl.AMPL_OBJECTIVE, name) + return Objective.create(self._c_ampl, name_c, NULL, None) def get_set(self, name): """ @@ -264,11 +294,19 @@ cdef class AMPL: Raises: KeyError: if the specified set does not exist. """ + cdef campl.AMPL_ERRORINFO* errorinfo + cdef campl.AMPL_RETCODE rc cdef campl.AMPL_ENTITYTYPE entitytype cdef char* name_c = strdup(name.encode('utf-8')) - PY_AMPL_CALL(campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype)) - if entitytype != campl.AMPL_SET: raiseKeyError(campl.AMPL_SET, name) - return Set.create(self._c_ampl, name_c, NULL) + errorinfo = campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype) + rc = campl.AMPL_ErrorInfoGetError(errorinfo) + if rc != campl.AMPL_OK: + free(name_c) + PY_AMPL_CALL(errorinfo) + if entitytype != campl.AMPL_SET: + free(name_c) + raiseKeyError(campl.AMPL_SET, name) + return Set.create(self._c_ampl, name_c, NULL, None) def get_parameter(self, name): """ @@ -280,11 +318,19 @@ cdef class AMPL: Raises: KeyError: if the specified parameter does not exist. """ + cdef campl.AMPL_ERRORINFO* errorinfo + cdef campl.AMPL_RETCODE rc cdef campl.AMPL_ENTITYTYPE entitytype cdef char* name_c = strdup(name.encode('utf-8')) - PY_AMPL_CALL(campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype)) - if entitytype != campl.AMPL_PARAMETER: raiseKeyError(campl.AMPL_PARAMETER, name) - return Parameter.create(self._c_ampl, name_c, NULL) + errorinfo = campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype) + rc = campl.AMPL_ErrorInfoGetError(errorinfo) + if rc != campl.AMPL_OK: + free(name_c) + PY_AMPL_CALL(errorinfo) + if entitytype != campl.AMPL_PARAMETER: + free(name_c) + raiseKeyError(campl.AMPL_PARAMETER, name) + return Parameter.create(self._c_ampl, name_c, NULL, None) def eval(self, statements): """ @@ -591,6 +637,8 @@ cdef class AMPL: Args: ampl_expressions: Expressions to be evaluated. """ + cdef campl.AMPL_ERRORINFO* errorinfo + cdef campl.AMPL_RETCODE rc exprs = list(map(str, ampl_expressions)) cdef int size = len(exprs) cdef const char** array = malloc(size * sizeof(const char*)) @@ -601,7 +649,13 @@ cdef class AMPL: array[i] = strdup(exprs[i].encode('utf-8')) display = "display" - PY_AMPL_CALL(campl.AMPL_CallVisualisationCommandOnNames(self._c_ampl, display.encode('utf-8'), array, size)) + errorinfo = campl.AMPL_CallVisualisationCommandOnNames(self._c_ampl, display.encode('utf-8'), array, size) + rc = campl.AMPL_ErrorInfoGetError(errorinfo) + for i in range(size): + free(array[i]) + free(array) + if rc != campl.AMPL_OK: + PY_AMPL_CALL(errorinfo) def set_output_handler(self, output_handler): """ diff --git a/amplpy/constraint.pxi b/amplpy/constraint.pxi index 25f76cf..f39987b 100644 --- a/amplpy/constraint.pxi +++ b/amplpy/constraint.pxi @@ -30,12 +30,15 @@ cdef class Constraint(Entity): and the :class:`~amplpy.DataFrame` class. """ @staticmethod - cdef create(campl.AMPL* ampl_c, char* name, campl.AMPL_TUPLE* index): + cdef create(campl.AMPL* ampl_c, char* name, campl.AMPL_TUPLE* index, parent): entity = Constraint() entity._c_ampl = ampl_c entity._name = name entity._index = index entity.wrap_function = campl.AMPL_CONSTRAINT + entity._entity = parent + if entity._entity is not None: + Py_INCREF(entity._entity) return entity def __setitem__(self, index, value): diff --git a/amplpy/dataframe.pxi b/amplpy/dataframe.pxi index c2a0cc1..633f067 100644 --- a/amplpy/dataframe.pxi +++ b/amplpy/dataframe.pxi @@ -308,6 +308,8 @@ cdef class DataFrame(object): values: The values to set. """ + cdef campl.AMPL_ERRORINFO* errorinfo + cdef campl.AMPL_RETCODE rc cdef double* c_double_array = NULL cdef char** c_string_array = NULL cdef size_t size = len(values) @@ -315,18 +317,23 @@ cdef class DataFrame(object): c_double_array = malloc(size * sizeof(double)) for i in range(size): c_double_array[i] = values[i] - PY_AMPL_CALL(campl.AMPL_DataFrameSetColumnArgDouble(self._c_df, header.encode('utf-8'), c_double_array, size)) + errorinfo = campl.AMPL_DataFrameSetColumnArgDouble(self._c_df, header.encode('utf-8'), c_double_array, size) + rc = campl.AMPL_ErrorInfoGetError(errorinfo) free(c_double_array) + if rc != campl.AMPL_OK: + PY_AMPL_CALL(errorinfo) elif isinstance(values[0], str): c_string_array = malloc(size * sizeof(char*)) for i in range(size): c_string_array[i] = strdup(values[i].encode('utf-8')) - PY_AMPL_CALL(campl.AMPL_DataFrameSetColumnArgString(self._c_df, header.encode('utf-8'), c_string_array, size)) + errorinfo = campl.AMPL_DataFrameSetColumnArgString(self._c_df, header.encode('utf-8'), c_string_array, size) + rc = campl.AMPL_ErrorInfoGetError(errorinfo) for i in range(size): if c_string_array[i] != NULL: free(c_string_array[i]) free(c_string_array) - #self._impl.setColumnPyList(header, list(values)) + if rc != campl.AMPL_OK: + PY_AMPL_CALL(errorinfo) def _get_row(self, key): """ diff --git a/amplpy/entity.pxi b/amplpy/entity.pxi index 4f8eadc..55bcdad 100644 --- a/amplpy/entity.pxi +++ b/amplpy/entity.pxi @@ -36,23 +36,35 @@ cdef class Entity(object): cdef char* _name cdef campl.AMPL_TUPLE* _index cdef campl.AMPL_ENTITYTYPE wrap_function + cdef object _entity @staticmethod - cdef create(campl.AMPL* ampl_c, char *name, campl.AMPL_TUPLE* index): + cdef create(campl.AMPL* ampl_c, char *name, campl.AMPL_TUPLE* index, parent): + cdef campl.AMPL_ERRORINFO* errorinfo + cdef campl.AMPL_RETCODE rc cdef campl.AMPL_ENTITYTYPE entitytype - PY_AMPL_CALL(campl.AMPL_EntityGetType(ampl_c, name, &entitytype)) + errorinfo = campl.AMPL_EntityGetType(ampl_c, name, &entitytype) + rc = campl.AMPL_ErrorInfoGetError(errorinfo) + if rc != campl.AMPL_OK: + free(name) + PY_AMPL_CALL(errorinfo) entity = Entity() entity._c_ampl = ampl_c entity._name = name entity._index = index - entity.wrap_function = campl.AMPL_UNDEFINED + entity.wrap_function = entitytype + entity._entity = parent + if entity._entity is not None: + Py_INCREF(entity._entity) return entity - #def __dealloc__(self): - # if self._name is not NULL: - # campl.AMPL_StringFree(&self._name) - #if self._index is not NULL: - # campl.AMPL_TupleFree(&self._index) + def __dealloc__(self): + if self._entity is not None: + Py_DECREF(self._entity) + if self._index is not NULL: + campl.AMPL_TupleFree(&self._index) + else: + campl.AMPL_StringFree(&self._name) def to_string(self): cdef char* output_c @@ -66,13 +78,13 @@ cdef class Entity(object): def __iter__(self): assert self.wrap_function is not None - return InstanceIterator.create(self._c_ampl, self._name, self.wrap_function) + return InstanceIterator.create(self._c_ampl, self._name, self.wrap_function, self) def __getitem__(self, index): if not isinstance(index, (tuple, list)): index = [index] cdef campl.AMPL_TUPLE* tuple_c = to_c_tuple(index) - return create_entity(self.wrap_function, self._c_ampl, self._name, tuple_c) + return create_entity(self.wrap_function, self._c_ampl, self._name, tuple_c, self) def get(self, *index): """ @@ -88,16 +100,16 @@ cdef class Entity(object): index = index[0] index = list(index) if len(index) == 0: - return create_entity(self.wrap_function, self._c_ampl, self._name, NULL) + return create_entity(self.wrap_function, self._c_ampl, self._name, NULL, None) else: tuple_c = to_c_tuple(index) if self.wrap_function == campl.AMPL_PARAMETER: campl.AMPL_InstanceGetName(self._c_ampl, self._name, tuple_c, &name_c) - entity = create_entity(self.wrap_function, self._c_ampl, name_c, NULL).value() - campl.AMPL_StringFree(&name_c) + campl.AMPL_TupleFree(&tuple_c) + entity = create_entity(self.wrap_function, self._c_ampl, name_c, NULL, None).value() return entity else: - return create_entity(self.wrap_function, self._c_ampl, self._name, tuple_c) + return create_entity(self.wrap_function, self._c_ampl, self._name, tuple_c, self) def find(self, index): """ @@ -107,6 +119,7 @@ cdef class Entity(object): The wanted instance if found, otherwise it returns `None`. """ assert self.wrap_function is not None + cdef size_t i cdef campl.AMPL_TUPLE* index_c = to_c_tuple(index) cdef campl.AMPL_TUPLE** indices_c cdef size_t size @@ -116,7 +129,9 @@ cdef class Entity(object): for j in range(size): campl.AMPL_TupleFree(&indices_c[j]) free(indices_c) - return create_entity(self.wrap_function, self._c_ampl, self._name, index_c) + return create_entity(self.wrap_function, self._c_ampl, self._name, index_c, self) + for i in range(size): + campl.AMPL_TupleFree(&indices_c[i]) free(indices_c) campl.AMPL_TupleFree(&index_c) return None @@ -125,7 +140,7 @@ cdef class Entity(object): """ Get all the instances in this entity. """ - return InstanceIterator.create(self._c_ampl, self._name, self.wrap_function) + return InstanceIterator.create(self._c_ampl, self._name, self.wrap_function, self) def name(self): """ diff --git a/amplpy/iterators.pxi b/amplpy/iterators.pxi index 5b47551..bc29db8 100644 --- a/amplpy/iterators.pxi +++ b/amplpy/iterators.pxi @@ -40,56 +40,73 @@ cdef class EntityMap(object): cdef campl.AMPL_ENTITYTYPE entity_class cdef char** begin cdef char** end - cdef char** iterator + cdef size_t iterator cdef size_t _size @staticmethod cdef create(campl.AMPL* ampl, campl.AMPL_ENTITYTYPE entity_class): + cdef campl.AMPL_ERRORINFO* errorinfo + cdef campl.AMPL_RETCODE rc entityit = EntityMap() entityit._c_ampl = ampl entityit.entity_class = entity_class if entity_class == campl.AMPL_VARIABLE: - campl.AMPL_GetVariables(entityit._c_ampl, &entityit._size, &entityit.begin) + errorinfo = campl.AMPL_GetVariables(entityit._c_ampl, &entityit._size, &entityit.begin) elif entity_class == campl.AMPL_CONSTRAINT: - campl.AMPL_GetConstraints(entityit._c_ampl, &entityit._size, &entityit.begin) + errorinfo = campl.AMPL_GetConstraints(entityit._c_ampl, &entityit._size, &entityit.begin) elif entity_class == campl.AMPL_OBJECTIVE: - campl.AMPL_GetObjectives(entityit._c_ampl, &entityit._size, &entityit.begin) + errorinfo = campl.AMPL_GetObjectives(entityit._c_ampl, &entityit._size, &entityit.begin) elif entity_class == campl.AMPL_SET: - campl.AMPL_GetSets(entityit._c_ampl, &entityit._size, &entityit.begin) + errorinfo = campl.AMPL_GetSets(entityit._c_ampl, &entityit._size, &entityit.begin) elif entity_class == campl.AMPL_PARAMETER: - campl.AMPL_GetParameters(entityit._c_ampl, &entityit._size, &entityit.begin) + errorinfo = campl.AMPL_GetParameters(entityit._c_ampl, &entityit._size, &entityit.begin) else: raise ValueError(f"Unknown entity class.") - entityit.iterator = entityit.begin + rc = campl.AMPL_ErrorInfoGetError(errorinfo) + if rc != campl.AMPL_OK: + for i in range(entityit._size): + campl.AMPL_StringFree(&entityit.begin[i]) + free(entityit.begin) + PY_AMPL_CALL(errorinfo) + + entityit.iterator = 0 entityit.end = entityit.begin + entityit._size return entityit def __dealloc__(self): + if self.iterator < self._size: + for i in range(self.iterator, self._size): + campl.AMPL_StringFree(&self.begin[i]) if self.begin != NULL: - for i in range(self._size): - free(self.begin[i]) free(self.begin) + self.begin = NULL def __iter__(self): - self.iterator = self.begin return self def __next__(self): - if self.iterator >= self.end: + if self.iterator >= self._size: raise StopIteration - cdef char** it = self.iterator + tuple = (self.begin[self.iterator].decode('utf-8'), create_entity(self.entity_class, self._c_ampl, self.begin[self.iterator], NULL, None)) self.iterator += 1 - name = it[0] - return (name, create_entity(self.entity_class, self._c_ampl, name, NULL)) + return tuple def __getitem__(self, key): assert isinstance(key, str) + cdef campl.AMPL_ERRORINFO* errorinfo + cdef campl.AMPL_RETCODE rc cdef campl.AMPL_ENTITYTYPE entitytype cdef char* name_c = strdup(key.encode('utf-8')) - PY_AMPL_CALL(campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype)) - if entitytype != self.entity_class: raiseKeyError(self.entity_class, name_c) - return create_entity(self.entity_class, self._c_ampl, name_c, NULL) + errorinfo = campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype) + rc = campl.AMPL_ErrorInfoGetError(errorinfo) + if rc != campl.AMPL_OK: + free(name_c) + PY_AMPL_CALL(errorinfo) + if entitytype != self.entity_class: + free(name_c) + raiseKeyError(self.entity_class, key) + return create_entity(self.entity_class, self._c_ampl, name_c, NULL, None) def size(self): return int(self._size) @@ -98,7 +115,6 @@ cdef class EntityMap(object): return self.size() cdef class InstanceIterator(object): - cdef campl.AMPL* _c_ampl cdef char* _name cdef campl.AMPL_ENTITYTYPE entity_class @@ -106,14 +122,16 @@ cdef class InstanceIterator(object): cdef campl.AMPL_TUPLE** end cdef size_t iterator cdef size_t _size + cdef object _entity @staticmethod - cdef create(campl.AMPL* ampl, char* name, campl.AMPL_ENTITYTYPE entity_class): + cdef create(campl.AMPL* ampl, char* name, campl.AMPL_ENTITYTYPE entity_class, parent): instanceit = InstanceIterator() cdef size_t arity instanceit._c_ampl = ampl instanceit._name = name instanceit.entity_class = entity_class + instanceit._entity = parent campl.AMPL_EntityGetIndexarity(instanceit._c_ampl, instanceit._name, &arity) if arity == 0: instanceit._size = 1 @@ -131,10 +149,12 @@ cdef class InstanceIterator(object): return instanceit def __dealloc__(self): - if self.begin != NULL: - for i in range(self._size): + if self.iterator+1 < self._size: + for i in range(self.iterator+1, self._size): campl.AMPL_TupleFree(&self.begin[i]) + if self.begin != NULL: free(self.begin) + self.begin = NULL def __iter__(self): return self @@ -144,15 +164,15 @@ cdef class InstanceIterator(object): raise StopIteration self.iterator += 1 if self.begin == NULL: - return (None, create_entity(self.entity_class, self._c_ampl, self._name, NULL)) + return (None, self._entity) else: - return (to_py_tuple(self.begin[self.iterator]), create_entity(self.entity_class, self._c_ampl, self._name, self.begin[self.iterator])) + return (to_py_tuple(self.begin[self.iterator]), create_entity(self.entity_class, self._c_ampl, self._name, self.begin[self.iterator], self._entity)) def __getitem__(self, key): assert isinstance(key, str) key = tuple(key) cdef campl.AMPL_TUPLE* tuple_c = to_c_tuple(key) - return create_entity(self.entity_class, self._c_ampl, self._name, tuple_c) + return create_entity(self.entity_class, self._c_ampl, self._name, tuple_c, self._entity) def size(self): return int(self._size) @@ -168,33 +188,38 @@ cdef class MemberRangeIterator(object): cdef campl.AMPL_TUPLE* _index cdef campl.AMPL_TUPLE** begin cdef campl.AMPL_TUPLE** end - cdef campl.AMPL_TUPLE** iterator + cdef size_t iterator cdef size_t _size + cdef object _entity @staticmethod - cdef create(campl.AMPL* ampl, char* name, campl.AMPL_TUPLE* index): + cdef create(campl.AMPL* ampl, char* name, campl.AMPL_TUPLE* index, parent): instanceit = MemberRangeIterator() instanceit._c_ampl = ampl instanceit._name = name instanceit._index = index + instanceit._entity = parent + if instanceit._entity is not None: + Py_INCREF(instanceit._entity) campl.AMPL_SetInstanceGetValues(instanceit._c_ampl, instanceit._name, instanceit._index, &instanceit.begin, &instanceit._size) + instanceit.iterator = 0 if instanceit._size == 0: - instanceit.iterator = NULL instanceit.end = NULL else: - instanceit.iterator = instanceit.begin instanceit.end = instanceit.begin + instanceit._size return instanceit def __dealloc__(self): + if self._entity is not None: + Py_DECREF(self._entity) for i in range(self._size): campl.AMPL_TupleFree(&self.begin[i]) - free(self.begin) + if self.begin != NULL: + free(self.begin) + self.begin = NULL def size(self): - cdef size_t size - campl.AMPL_SetInstanceGetSize(self._c_ampl, self._name, self._index, &size) - return int(size) + return int(self._size) def __len__(self): return self.size() @@ -203,20 +228,22 @@ cdef class MemberRangeIterator(object): return self def __next__(self): - if self.iterator >= self.end or self.iterator == NULL: + if self.iterator >= self._size: raise StopIteration - cdef campl.AMPL_TUPLE** it = self.iterator + cdef campl.AMPL_TUPLE* it = self.begin[self.iterator] cdef size_t size cdef campl.AMPL_VARIANT* variant - campl.AMPL_TupleGetSize(it[0], &size) + campl.AMPL_TupleGetSize(it, &size) self.iterator += 1 if size == 1: - campl.AMPL_TupleGetVariant(it[0], 0, &variant) - return to_py_variant(variant) + campl.AMPL_TupleGetVariant(it, 0, &variant) + py_variant = to_py_variant(variant) + self.iterator += 1 + return py_variant else: - return to_py_tuple(it[0]) - - + py_tuple = to_py_tuple(it) + self.iterator += 1 + return py_tuple cdef class ColIterator(object): cdef campl.AMPL_DATAFRAME* _df diff --git a/amplpy/objective.pxi b/amplpy/objective.pxi index 5724f25..651006e 100644 --- a/amplpy/objective.pxi +++ b/amplpy/objective.pxi @@ -22,12 +22,15 @@ cdef class Objective(Entity): and the :class:`~amplpy.DataFrame` class. """ @staticmethod - cdef create(campl.AMPL* ampl_c, char* name, campl.AMPL_TUPLE* index): + cdef create(campl.AMPL* ampl_c, char* name, campl.AMPL_TUPLE* index, parent): entity = Objective() entity._c_ampl = ampl_c entity._name = name entity._index = index entity.wrap_function = campl.AMPL_OBJECTIVE + entity._parent = parent + if entity._entity is not None: + Py_INCREF(entity._entity) return entity def value(self): diff --git a/amplpy/parameter.pxi b/amplpy/parameter.pxi index fa83d34..5e83bbc 100644 --- a/amplpy/parameter.pxi +++ b/amplpy/parameter.pxi @@ -29,12 +29,15 @@ cdef class Parameter(Entity): and an object of class :class:`~amplpy.DataFrame`. """ @staticmethod - cdef create(campl.AMPL* ampl_c, char* name, campl.AMPL_TUPLE* index): + cdef create(campl.AMPL* ampl_c, char* name, campl.AMPL_TUPLE* index, parent): entity = Parameter() entity._c_ampl = ampl_c entity._name = name entity._index = index entity.wrap_function = campl.AMPL_PARAMETER + entity._entity = parent + if entity._entity is not None: + Py_INCREF(entity._entity) return entity def __setitem__(self, index, value): @@ -78,6 +81,7 @@ cdef class Parameter(Entity): cdef char* expression cdef campl.AMPL_VARIANT* v campl.AMPL_InstanceGetName(self._c_ampl, self._name, tuple_c, &expression) + campl.AMPL_TupleFree(&tuple_c) campl.AMPL_GetValue(self._c_ampl, expression, &v) campl.AMPL_StringFree(&expression) py_variant = to_py_variant(v) @@ -127,9 +131,12 @@ cdef class Parameter(Entity): index_c = to_c_tuple(index) if isinstance(value, Real): campl.AMPL_ParameterInstanceSetNumericValue(self._c_ampl, self._name, index_c, float(value)) + campl.AMPL_TupleFree(&index_c) elif isinstance(value, str): campl.AMPL_ParameterInstanceSetStringValue(self._c_ampl, self._name, index_c, value.encode('utf-8')) + campl.AMPL_TupleFree(&index_c) else: + campl.AMPL_TupleFree(&index_c) raise TypeError def set_values(self, values): diff --git a/amplpy/set.pxi b/amplpy/set.pxi index c3b23d0..5a85b13 100644 --- a/amplpy/set.pxi +++ b/amplpy/set.pxi @@ -29,12 +29,15 @@ cdef class Set(Entity): :class:`~amplpy.DataFrame`. """ @staticmethod - cdef create(campl.AMPL* ampl_c, char* name, campl.AMPL_TUPLE* index): + cdef create(campl.AMPL* ampl_c, char* name, campl.AMPL_TUPLE* index, parent): entity = Set() entity._c_ampl = ampl_c entity._name = name entity._index = index entity.wrap_function = campl.AMPL_SET + entity._entity = parent + if entity._entity is not None: + Py_INCREF(entity._entity) return entity def __setitem__(self, index, value): @@ -73,7 +76,7 @@ cdef class Set(Entity): """ Get members (tuples) of this Set. Valid only for non-indexed sets. """ - return MemberRangeIterator.create(self._c_ampl, self._name, self._index) + return MemberRangeIterator.create(self._c_ampl, self._name, self._index, self) def size(self): """ @@ -94,6 +97,7 @@ cdef class Set(Entity): cdef bool_c contains_c cdef campl.AMPL_TUPLE* t_c = to_c_tuple(t) campl.AMPL_SetInstanceContains(self._c_ampl, self._name, NULL, t_c, &contains_c) + campl.AMPL_TupleFree(&t_c) return contains_c def set_values(self, values): diff --git a/amplpy/util.pxi b/amplpy/util.pxi index 3ef3f85..6dd7959 100644 --- a/amplpy/util.pxi +++ b/amplpy/util.pxi @@ -135,19 +135,19 @@ cdef campl.AMPL_VARIANT* to_c_variant(value): raise ValueError(f"unsupported type {type(value)}") return variant -cdef create_entity(campl.AMPL_ENTITYTYPE entity_class, campl.AMPL* ampl, char* name, campl.AMPL_TUPLE* index): +cdef create_entity(campl.AMPL_ENTITYTYPE entity_class, campl.AMPL* ampl, char* name, campl.AMPL_TUPLE* index, parent): if entity_class == campl.AMPL_VARIABLE: - return Variable.create(ampl, name, index) + return Variable.create(ampl, name, index, parent) elif entity_class == campl.AMPL_CONSTRAINT: - return Constraint.create(ampl, name, index) + return Constraint.create(ampl, name, index, parent) elif entity_class == campl.AMPL_OBJECTIVE: - return Objective.create(ampl, name, index) + return Objective.create(ampl, name, index, parent) elif entity_class == campl.AMPL_SET: - return Set.create(ampl, name, index) + return Set.create(ampl, name, index, parent) elif entity_class == campl.AMPL_PARAMETER: - return Parameter.create(ampl, name, index) + return Parameter.create(ampl, name, index, parent) else: - return Entity.create(ampl, name, index) + return Entity.create(ampl, name, index, parent) cdef void setValuesParamNum(campl.AMPL* ampl, char* name, values): cdef size_t size = len(values) diff --git a/amplpy/variable.pxi b/amplpy/variable.pxi index 7d25605..93e48db 100644 --- a/amplpy/variable.pxi +++ b/amplpy/variable.pxi @@ -21,12 +21,15 @@ cdef class Variable(Entity): and the :class:`~amplpy.DataFrame` class. """ @staticmethod - cdef create(campl.AMPL* ampl_c, char *name, campl.AMPL_TUPLE* index): + cdef create(campl.AMPL* ampl_c, char *name, campl.AMPL_TUPLE* index, parent): entity = Variable() entity._c_ampl = ampl_c entity._name = name entity._index = index entity.wrap_function = campl.AMPL_VARIABLE + entity._entity = parent + if entity._entity is not None: + Py_INCREF(entity._entity) return entity def __setitem__(self, index, value):