Skip to content

Commit

Permalink
fix more leaks
Browse files Browse the repository at this point in the history
  • Loading branch information
jurgen-lentz committed Oct 27, 2024
1 parent 5f9c367 commit ee216a7
Show file tree
Hide file tree
Showing 10 changed files with 213 additions and 90 deletions.
90 changes: 72 additions & 18 deletions amplpy/ampl.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ from libc.string cimport strdup

from cpython.bool cimport PyBool_Check

from cpython cimport Py_INCREF, Py_DECREF

from numbers import Real
from ast import literal_eval
Expand Down Expand Up @@ -179,14 +180,19 @@ cdef class AMPL:
DataFrame capturing the output of the display
command in tabular form.
"""
cdef campl.AMPL_ERRORINFO* errorinfo
cdef campl.AMPL_RETCODE rc
cdef campl.AMPL_DATAFRAME* data
cdef char** statements_c = <char**> malloc(len(statements) * sizeof(char*))
for i in range(len(statements)):
statements_c[i] = strdup(statements[i].encode('utf-8'))
PY_AMPL_CALL(campl.AMPL_GetData(self._c_ampl, statements_c, len(statements), &data))
errorinfo = campl.AMPL_GetData(self._c_ampl, statements_c, len(statements), &data)
rc = campl.AMPL_ErrorInfoGetError(errorinfo)
for i in range(len(statements)):
free(statements_c[i])
free(statements_c)
if rc != campl.AMPL_OK:
PY_AMPL_CALL(errorinfo)
return DataFrame.create(data)

def get_entity(self, name):
Expand All @@ -204,7 +210,7 @@ cdef class AMPL:
The AMPL entity with the specified name.
"""
cdef char* name_c = strdup(name.encode('utf-8'))
return Entity.create(self._c_ampl, name_c, NULL)
return Entity.create(self._c_ampl, name_c, NULL, None)

def get_variable(self, name):
"""
Expand All @@ -216,11 +222,19 @@ cdef class AMPL:
Raises:
KeyError: if the specified variable does not exist.
"""
cdef campl.AMPL_ERRORINFO* errorinfo
cdef campl.AMPL_RETCODE rc
cdef campl.AMPL_ENTITYTYPE entitytype
cdef char* name_c = strdup(name.encode('utf-8'))
PY_AMPL_CALL(campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype))
if entitytype != campl.AMPL_VARIABLE: raiseKeyError(campl.AMPL_VARIABLE, name)
return Variable.create(self._c_ampl, name_c, NULL)
errorinfo = campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype)
rc = campl.AMPL_ErrorInfoGetError(errorinfo)
if rc != campl.AMPL_OK:
free(name_c)
PY_AMPL_CALL(errorinfo)
if entitytype != campl.AMPL_VARIABLE:
free(name_c)
raiseKeyError(campl.AMPL_VARIABLE, name)
return Variable.create(self._c_ampl, name_c, NULL, None)

def get_constraint(self, name):
"""
Expand All @@ -232,11 +246,19 @@ cdef class AMPL:
Raises:
KeyError: if the specified constraint does not exist.
"""
cdef campl.AMPL_ERRORINFO* errorinfo
cdef campl.AMPL_RETCODE rc
cdef campl.AMPL_ENTITYTYPE entitytype
cdef char* name_c = strdup(name.encode('utf-8'))
PY_AMPL_CALL(campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype))
if entitytype != campl.AMPL_CONSTRAINT: raiseKeyError(campl.AMPL_CONSTRAINT, name)
return Constraint.create(self._c_ampl, name_c, NULL)
errorinfo = campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype)
rc = campl.AMPL_ErrorInfoGetError(errorinfo)
if rc != campl.AMPL_OK:
free(name_c)
PY_AMPL_CALL(errorinfo)
if entitytype != campl.AMPL_CONSTRAINT:
free(name_c)
raiseKeyError(campl.AMPL_CONSTRAINT, name)
return Constraint.create(self._c_ampl, name_c, NULL, None)

def get_objective(self, name):
"""
Expand All @@ -248,11 +270,19 @@ cdef class AMPL:
Raises:
KeyError: if the specified objective does not exist.
"""
cdef campl.AMPL_ERRORINFO* errorinfo
cdef campl.AMPL_RETCODE rc
cdef campl.AMPL_ENTITYTYPE entitytype
cdef char* name_c = strdup(name.encode('utf-8'))
PY_AMPL_CALL(campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype))
if entitytype != campl.AMPL_OBJECTIVE: raiseKeyError(campl.AMPL_OBJECTIVE, name)
return Objective.create(self._c_ampl, name_c, NULL)
errorinfo = campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype)
rc = campl.AMPL_ErrorInfoGetError(errorinfo)
if rc != campl.AMPL_OK:
free(name_c)
PY_AMPL_CALL(errorinfo)
if entitytype != campl.AMPL_OBJECTIVE:
free(name_c)
raiseKeyError(campl.AMPL_OBJECTIVE, name)
return Objective.create(self._c_ampl, name_c, NULL, None)

def get_set(self, name):
"""
Expand All @@ -264,11 +294,19 @@ cdef class AMPL:
Raises:
KeyError: if the specified set does not exist.
"""
cdef campl.AMPL_ERRORINFO* errorinfo
cdef campl.AMPL_RETCODE rc
cdef campl.AMPL_ENTITYTYPE entitytype
cdef char* name_c = strdup(name.encode('utf-8'))
PY_AMPL_CALL(campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype))
if entitytype != campl.AMPL_SET: raiseKeyError(campl.AMPL_SET, name)
return Set.create(self._c_ampl, name_c, NULL)
errorinfo = campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype)
rc = campl.AMPL_ErrorInfoGetError(errorinfo)
if rc != campl.AMPL_OK:
free(name_c)
PY_AMPL_CALL(errorinfo)
if entitytype != campl.AMPL_SET:
free(name_c)
raiseKeyError(campl.AMPL_SET, name)
return Set.create(self._c_ampl, name_c, NULL, None)

def get_parameter(self, name):
"""
Expand All @@ -280,11 +318,19 @@ cdef class AMPL:
Raises:
KeyError: if the specified parameter does not exist.
"""
cdef campl.AMPL_ERRORINFO* errorinfo
cdef campl.AMPL_RETCODE rc
cdef campl.AMPL_ENTITYTYPE entitytype
cdef char* name_c = strdup(name.encode('utf-8'))
PY_AMPL_CALL(campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype))
if entitytype != campl.AMPL_PARAMETER: raiseKeyError(campl.AMPL_PARAMETER, name)
return Parameter.create(self._c_ampl, name_c, NULL)
errorinfo = campl.AMPL_EntityGetType(self._c_ampl, name_c, &entitytype)
rc = campl.AMPL_ErrorInfoGetError(errorinfo)
if rc != campl.AMPL_OK:
free(name_c)
PY_AMPL_CALL(errorinfo)
if entitytype != campl.AMPL_PARAMETER:
free(name_c)
raiseKeyError(campl.AMPL_PARAMETER, name)
return Parameter.create(self._c_ampl, name_c, NULL, None)

def eval(self, statements):
"""
Expand Down Expand Up @@ -591,6 +637,8 @@ cdef class AMPL:
Args:
ampl_expressions: Expressions to be evaluated.
"""
cdef campl.AMPL_ERRORINFO* errorinfo
cdef campl.AMPL_RETCODE rc
exprs = list(map(str, ampl_expressions))
cdef int size = len(exprs)
cdef const char** array = <const char**>malloc(size * sizeof(const char*))
Expand All @@ -601,7 +649,13 @@ cdef class AMPL:
array[i] = strdup(exprs[i].encode('utf-8'))

display = "display"
PY_AMPL_CALL(campl.AMPL_CallVisualisationCommandOnNames(self._c_ampl, display.encode('utf-8'), <const char* const*>array, size))
errorinfo = campl.AMPL_CallVisualisationCommandOnNames(self._c_ampl, display.encode('utf-8'), <const char* const*>array, size)
rc = campl.AMPL_ErrorInfoGetError(errorinfo)
for i in range(size):
free(array[i])
free(array)
if rc != campl.AMPL_OK:
PY_AMPL_CALL(errorinfo)

def set_output_handler(self, output_handler):
"""
Expand Down
5 changes: 4 additions & 1 deletion amplpy/constraint.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,15 @@ cdef class Constraint(Entity):
and the :class:`~amplpy.DataFrame` class.
"""
@staticmethod
cdef create(campl.AMPL* ampl_c, char* name, campl.AMPL_TUPLE* index):
cdef create(campl.AMPL* ampl_c, char* name, campl.AMPL_TUPLE* index, parent):
entity = Constraint()
entity._c_ampl = ampl_c
entity._name = name
entity._index = index
entity.wrap_function = campl.AMPL_CONSTRAINT
entity._entity = parent
if entity._entity is not None:
Py_INCREF(entity._entity)
return entity

def __setitem__(self, index, value):
Expand Down
13 changes: 10 additions & 3 deletions amplpy/dataframe.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -308,25 +308,32 @@ cdef class DataFrame(object):
values: The values to set.
"""
cdef campl.AMPL_ERRORINFO* errorinfo
cdef campl.AMPL_RETCODE rc
cdef double* c_double_array = NULL
cdef char** c_string_array = NULL
cdef size_t size = len(values)
if isinstance(values[0], Real):
c_double_array = <double*> malloc(size * sizeof(double))
for i in range(size):
c_double_array[i] = values[i]
PY_AMPL_CALL(campl.AMPL_DataFrameSetColumnArgDouble(self._c_df, header.encode('utf-8'), c_double_array, size))
errorinfo = campl.AMPL_DataFrameSetColumnArgDouble(self._c_df, header.encode('utf-8'), c_double_array, size)
rc = campl.AMPL_ErrorInfoGetError(errorinfo)
free(c_double_array)
if rc != campl.AMPL_OK:
PY_AMPL_CALL(errorinfo)
elif isinstance(values[0], str):
c_string_array = <char**> malloc(size * sizeof(char*))
for i in range(size):
c_string_array[i] = strdup(values[i].encode('utf-8'))
PY_AMPL_CALL(campl.AMPL_DataFrameSetColumnArgString(self._c_df, header.encode('utf-8'), c_string_array, size))
errorinfo = campl.AMPL_DataFrameSetColumnArgString(self._c_df, header.encode('utf-8'), c_string_array, size)
rc = campl.AMPL_ErrorInfoGetError(errorinfo)
for i in range(size):
if c_string_array[i] != NULL:
free(c_string_array[i])
free(c_string_array)
#self._impl.setColumnPyList(header, list(values))
if rc != campl.AMPL_OK:
PY_AMPL_CALL(errorinfo)

def _get_row(self, key):
"""
Expand Down
47 changes: 31 additions & 16 deletions amplpy/entity.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,35 @@ cdef class Entity(object):
cdef char* _name
cdef campl.AMPL_TUPLE* _index
cdef campl.AMPL_ENTITYTYPE wrap_function
cdef object _entity

@staticmethod
cdef create(campl.AMPL* ampl_c, char *name, campl.AMPL_TUPLE* index):
cdef create(campl.AMPL* ampl_c, char *name, campl.AMPL_TUPLE* index, parent):
cdef campl.AMPL_ERRORINFO* errorinfo
cdef campl.AMPL_RETCODE rc
cdef campl.AMPL_ENTITYTYPE entitytype
PY_AMPL_CALL(campl.AMPL_EntityGetType(ampl_c, name, &entitytype))
errorinfo = campl.AMPL_EntityGetType(ampl_c, name, &entitytype)
rc = campl.AMPL_ErrorInfoGetError(errorinfo)
if rc != campl.AMPL_OK:
free(name)
PY_AMPL_CALL(errorinfo)
entity = Entity()
entity._c_ampl = ampl_c
entity._name = name
entity._index = index
entity.wrap_function = campl.AMPL_UNDEFINED
entity.wrap_function = entitytype
entity._entity = parent
if entity._entity is not None:
Py_INCREF(entity._entity)
return entity

#def __dealloc__(self):
# if self._name is not NULL:
# campl.AMPL_StringFree(&self._name)
#if self._index is not NULL:
# campl.AMPL_TupleFree(&self._index)
def __dealloc__(self):
if self._entity is not None:
Py_DECREF(self._entity)
if self._index is not NULL:
campl.AMPL_TupleFree(&self._index)
else:
campl.AMPL_StringFree(&self._name)

def to_string(self):
cdef char* output_c
Expand All @@ -66,13 +78,13 @@ cdef class Entity(object):

def __iter__(self):
assert self.wrap_function is not None
return InstanceIterator.create(self._c_ampl, self._name, self.wrap_function)
return InstanceIterator.create(self._c_ampl, self._name, self.wrap_function, self)

def __getitem__(self, index):
if not isinstance(index, (tuple, list)):
index = [index]
cdef campl.AMPL_TUPLE* tuple_c = to_c_tuple(index)
return create_entity(self.wrap_function, self._c_ampl, self._name, tuple_c)
return create_entity(self.wrap_function, self._c_ampl, self._name, tuple_c, self)

def get(self, *index):
"""
Expand All @@ -88,16 +100,16 @@ cdef class Entity(object):
index = index[0]
index = list(index)
if len(index) == 0:
return create_entity(self.wrap_function, self._c_ampl, self._name, NULL)
return create_entity(self.wrap_function, self._c_ampl, self._name, NULL, None)
else:
tuple_c = to_c_tuple(index)
if self.wrap_function == campl.AMPL_PARAMETER:
campl.AMPL_InstanceGetName(self._c_ampl, self._name, tuple_c, &name_c)
entity = create_entity(self.wrap_function, self._c_ampl, name_c, NULL).value()
campl.AMPL_StringFree(&name_c)
campl.AMPL_TupleFree(&tuple_c)
entity = create_entity(self.wrap_function, self._c_ampl, name_c, NULL, None).value()
return entity
else:
return create_entity(self.wrap_function, self._c_ampl, self._name, tuple_c)
return create_entity(self.wrap_function, self._c_ampl, self._name, tuple_c, self)

def find(self, index):
"""
Expand All @@ -107,6 +119,7 @@ cdef class Entity(object):
The wanted instance if found, otherwise it returns `None`.
"""
assert self.wrap_function is not None
cdef size_t i
cdef campl.AMPL_TUPLE* index_c = to_c_tuple(index)
cdef campl.AMPL_TUPLE** indices_c
cdef size_t size
Expand All @@ -116,7 +129,9 @@ cdef class Entity(object):
for j in range(size):
campl.AMPL_TupleFree(&indices_c[j])
free(indices_c)
return create_entity(self.wrap_function, self._c_ampl, self._name, index_c)
return create_entity(self.wrap_function, self._c_ampl, self._name, index_c, self)
for i in range(size):
campl.AMPL_TupleFree(&indices_c[i])
free(indices_c)
campl.AMPL_TupleFree(&index_c)
return None
Expand All @@ -125,7 +140,7 @@ cdef class Entity(object):
"""
Get all the instances in this entity.
"""
return InstanceIterator.create(self._c_ampl, self._name, self.wrap_function)
return InstanceIterator.create(self._c_ampl, self._name, self.wrap_function, self)

def name(self):
"""
Expand Down
Loading

0 comments on commit ee216a7

Please sign in to comment.