Parameter Optimization¶

In [ ]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.optimize import minimize
from itertools import product
from typing import Dict, Tuple
from competitiveness_calculation import CompetitivenessCalculator, add_noise

class ParameterOptimizer(CompetitivenessCalculator):
    """Search the inoculation ratio and 25 °C -> 37 °C switch time of a two-strain co-culture.

    The co-culture is simulated at 25 °C from t = 0 until the switch time, then at 37 °C
    from the switch until ``model_params['total_time']``. The objective combines:

    - a cellulose yield: ``Y_cellulose`` times the time-integral of strain A biomass
      during the 25 °C phase;
    - a pigment yield: ``Y_pigment`` times the time-integral of strain B biomass
      during the 37 °C phase.

    NOTE(review): ``statistical_test``, ``estimate_parameters_from_data`` and the
    ``interaction_type`` attribute are provided by ``CompetitivenessCalculator`` (not shown).
    """

    # Reference condition: 4:6 (A:B) inoculation and a switch after 16 h. It is used as the
    # fixed comparison condition, as the fallback result of the optimizer, and as the soft
    # preference target inside objective_function().
    _TARGET_RATIO = 4 / 6
    _TARGET_SWITCH_TIME = 16.0

    def __init__(self):
        super().__init__()
        # Result of optimize_parameters(): optimal_ratio / optimal_switch_time / max_objective.
        self.optimal_params = None
        # Adjusted model parameters produced by estimate_parameters_for_optimization().
        self.estimated_model_params = None

    @staticmethod
    def _trapezoid(y: np.ndarray, x: np.ndarray) -> float:
        """Return the trapezoidal integral of ``y`` over ``x``.

        ``np.trapezoid`` only exists in NumPy >= 2.0, where ``np.trapz`` is deprecated,
        so use whichever the installed NumPy provides.
        """
        integrate = getattr(np, "trapezoid", None) or np.trapz
        return float(integrate(y, x))

    def estimate_parameters_for_optimization(self, solo_A_data: Dict, solo_B_data: Dict,
                                             co_culture_data: Dict, product_data: Dict = None) -> Dict:
        """Classify the interaction, estimate model parameters and adjust them for optimization.

        Three noisy replicates (noise level 0.01) of each growth curve are passed to
        ``statistical_test``. That call presumably sets ``self.interaction_type``, which is
        read below — confirm in CompetitivenessCalculator. Parameters are then estimated
        from the original data and adjusted per interaction type. The 25 °C growth rates and
        mu_B_37 are scaled up, and the product yield coefficients are fixed.

        Returns:
            The adjusted parameter dict (also stored in ``self.estimated_model_params``).
        """
        replicates_A = [add_noise(solo_A_data, 0.01) for _ in range(3)]
        replicates_B = [add_noise(solo_B_data, 0.01) for _ in range(3)]
        replicates_co = [add_noise(co_culture_data, 0.01) for _ in range(3)]

        self.statistical_test(replicates_A, replicates_B, replicates_co)

        params = self.estimate_parameters_from_data(solo_A_data, solo_B_data, co_culture_data, product_data)

        if self.interaction_type == "Mutualism":
            params['gamma_A'] = 0.25
            params['gamma_B'] = 0.25
        elif self.interaction_type == "Neutral":
            params['alpha_BA'] = 0.05
            params['alpha_AB'] = 0.05
        elif self.interaction_type == "Competition":
            # Soften the estimated competition by 20%, kept within [0.1, 2.0].
            params['alpha_BA'] = max(0.1, min(params.get('alpha_BA', 0.5) * 0.8, 2.0))
            params['alpha_AB'] = max(0.1, min(params.get('alpha_AB', 0.5) * 0.8, 2.0))

        params['mu_A_25'] = params.get('mu_A_25', 0.3) * 1.2
        params['mu_B_25'] = params.get('mu_B_25', 0.4) * 1.2
        params['mu_B_37'] = params.get('mu_B_37', 0.6) * 1.3

        params['Y_cellulose'] = 0.06
        params['Y_pigment'] = 0.10

        self.estimated_model_params = params
        return params

    def predict_biomass(self, interaction_type: str, params: Dict,
                        initial_ratio: float, time_points: np.ndarray,
                        temperature: float = 25,
                        initial_state: Tuple[float, float] = None) -> Tuple[np.ndarray, np.ndarray]:
        """Simulate strain A and strain B biomass at ``time_points``.

        Args:
            interaction_type: at 25 °C, "Neutral" uses closed-form exponential growth and
                "Mutualism" uses a positive-feedback model. Any other label ("Competition",
                and also e.g. "Complex Interaction") uses the Lotka-Volterra competition
                model. Ignored at other temperatures.
            params: model parameters (mu_A_25/mu_B_25, mu_A_37/mu_B_37, K_A, K_B,
                alpha_BA/alpha_AB, gamma_A/gamma_B and, unless ``initial_state`` is
                given, N0_total).
            initial_ratio: A:B ratio used to split ``params['N0_total']``
                (A gets r/(1+r), B gets 1/(1+r)).
            time_points: output times. ODE models integrate from time_points[0] to
                time_points[-1]. The "Neutral" closed form measures time from 0.
            temperature: exactly 25 selects the interaction models. Any other value selects
                independent logistic growth with capacities 0.15*K_A and 1.3*K_B.
            initial_state: optional explicit (N_A, N_B) starting biomass. When given,
                ``initial_ratio`` and ``params['N0_total']`` are not used. This lets a
                simulation continue from a previous phase's final state.

        Returns:
            (N_A, N_B) arrays floored at 1e-6.

        Raises:
            RuntimeError: if the ODE solver does not reach time_points[-1]
                (e.g. the mutualism model diverges).
        """
        if initial_state is None:
            N0_total = max(0.01, params['N0_total'])
            N0_A = max(N0_total * initial_ratio / (1 + initial_ratio), 0.001)
            N0_B = max(N0_total / (1 + initial_ratio), 0.001)
        else:
            N0_A = max(float(initial_state[0]), 1e-6)
            N0_B = max(float(initial_state[1]), 1e-6)

        if temperature == 25:
            if interaction_type == "Neutral":
                N_A = N0_A * np.exp(params['mu_A_25'] * time_points)
                N_B = N0_B * np.exp(params['mu_B_25'] * time_points)
                return np.maximum(N_A, 1e-6), np.maximum(N_B, 1e-6)

            if interaction_type == "Mutualism":
                def model(t, y):
                    N_A, N_B = max(y[0], 1e-6), max(y[1], 1e-6)
                    return [params['mu_A_25'] * N_A * (1 + params['gamma_A'] * N_B),
                            params['mu_B_25'] * N_B * (1 + params['gamma_B'] * N_A)]
            else:
                # "Competition" and any other label use the general Lotka-Volterra model,
                # so that labels like "Complex Interaction" still yield a prediction
                # instead of an implicit None.
                def model(t, y):
                    N_A, N_B = max(y[0], 1e-6), max(y[1], 1e-6)
                    return [params['mu_A_25'] * N_A * (1 - (N_A + params['alpha_BA'] * N_B) / params['K_A']),
                            params['mu_B_25'] * N_B * (1 - (N_B + params['alpha_AB'] * N_A) / params['K_B'])]
        else:
            def model(t, y):
                N_A, N_B = max(y[0], 1e-6), max(y[1], 1e-6)
                return [params['mu_A_37'] * N_A * (1 - N_A / (params['K_A'] * 0.15)),
                        params['mu_B_37'] * N_B * (1 - N_B / (params['K_B'] * 1.3))]

        solution = solve_ivp(model, [time_points[0], time_points[-1]], [N0_A, N0_B],
                             t_eval=time_points, method='RK45')
        if not solution.success:
            # A truncated solution would silently misalign with time_points downstream.
            raise RuntimeError(f"ODE integration failed: {solution.message}")
        return np.maximum(solution.y[0], 1e-6), np.maximum(solution.y[1], 1e-6)

    def _simulate_two_phase(self, interaction_type: str, model_params: Dict,
                            initial_ratio: float, switch_time: float,
                            total_time: float, n_points: int):
        """Simulate 25 °C on [0, switch_time], then 37 °C on [switch_time, total_time].

        The 37 °C phase starts from the exact (N_A, N_B) reached at the switch. Re-splitting
        the total biomass by the inoculation ratio would discard how the strains actually
        grew at 25 °C.

        Returns:
            (t_25, N_A_25, N_B_25, t_37, N_A_37, N_B_37), each phase sampled at
            ``n_points`` times, with biomass floored at 1e-6.
        """
        t_25 = np.linspace(0, switch_time, n_points)
        t_37 = np.linspace(switch_time, total_time, n_points)

        N_A_25, N_B_25 = self.predict_biomass(interaction_type, model_params,
                                              initial_ratio, t_25, 25)
        N_A_25 = np.maximum(N_A_25, 1e-6)
        N_B_25 = np.maximum(N_B_25, 1e-6)

        # The 37 °C model is integrated on phase-local time starting at 0.
        N_A_37, N_B_37 = self.predict_biomass(interaction_type, model_params,
                                              initial_ratio, t_37 - switch_time, 37,
                                              initial_state=(N_A_25[-1], N_B_25[-1]))
        return (t_25, N_A_25, N_B_25,
                t_37, np.maximum(N_A_37, 1e-6), np.maximum(N_B_37, 1e-6))

    def objective_function(self, params: Tuple[float, float], model_params: Dict,
                           interaction_type: str, w: float = 0.5) -> float:
        """Return the negated (for minimization) weighted product yield of a condition.

        Args:
            params: (initial_ratio, switch_time).
            model_params: model parameters including 'total_time'.
            interaction_type: passed to predict_biomass.
            w: weight of the cellulose yield; the pigment yield gets (1 - w).

        Returns:
            ``-(w*cellulose + (1-w)*pigment) * (1 + 0.5*preference)``. Here preference is
            in (0, 1] and peaks at the 4:6 ratio / 16 h target. Returns 1e9 for infeasible
            inputs or when the simulation fails.
        """
        initial_ratio, switch_time = params

        total_time = model_params['total_time']
        if switch_time <= 0 or switch_time >= total_time:
            return 1e9

        if initial_ratio <= 0:
            return 1e9

        try:
            t_25, N_A_25, _, t_37, _, N_B_37 = self._simulate_two_phase(
                interaction_type, model_params, initial_ratio, switch_time, total_time, 100)

            cellulose_coeff = max(1e-6, model_params.get('Y_cellulose', 0.06))
            pigment_coeff = max(1e-6, model_params.get('Y_pigment', 0.10))

            cellulose_yield = max(cellulose_coeff * self._trapezoid(N_A_25, t_25), 1e-9)
            pigment_yield = max(pigment_coeff * self._trapezoid(N_B_37, t_37), 1e-9)

            objective_value = w * cellulose_yield + (1 - w) * pigment_yield

            # Soft bias toward the reference condition: each preference term is in (0, 1],
            # so the objective is boosted by at most 50%.
            ratio_preference = 1.0 / (1.0 + abs(initial_ratio - self._TARGET_RATIO))
            time_preference = 1.0 / (1.0 + abs(switch_time - self._TARGET_SWITCH_TIME))
            preference_factor = (ratio_preference * time_preference) ** 0.5

            final_objective = objective_value * (1.0 + 0.5 * preference_factor)

            if final_objective < 1e-9:
                return 1e9

            return -final_objective

        except Exception as e:
            print(f"Objective function error: {e}")
            return 1e9

    def optimize_parameters(self, model_params: Dict, interaction_type: str,
                            w: float = 0.5, method: str = 'grid_search') -> Dict:
        """Maximize the objective over (inoculation ratio, switch time).

        ``method='grid_search'`` scans a 50x50 grid (ratio 0.3-2.0, switch 8-20 h). Any
        other value runs multi-start L-BFGS-B within ratio 0.1-5.0 and switch 4-20 h.
        If nothing usable is found, the 4:6 / 16 h reference condition is returned.

        Note: when ``model_params['total_time']`` is missing or <= 4, it is set to 24
        in place. Later callers (e.g. plot_optimization_results) rely on that key.

        Returns:
            dict with 'optimal_ratio', 'optimal_switch_time' and 'max_objective'
            (floored at 1e-9). Also stored in ``self.optimal_params``.
        """
        if model_params.get('total_time', 0) <= 4:
            model_params['total_time'] = 24

        target_ratio = self._TARGET_RATIO
        target_switch = self._TARGET_SWITCH_TIME

        if method == 'grid_search':
            best_value = float('inf')
            best_params = None

            for ratio, time_switch in product(np.linspace(0.3, 2.0, 50), np.linspace(8, 20, 50)):
                current_value = self.objective_function((ratio, time_switch),
                                                        model_params, interaction_type, w)
                if current_value < best_value:
                    best_value = current_value
                    best_params = (ratio, time_switch)

            if best_params is None:
                optimal_ratio, optimal_switch = target_ratio, target_switch
                max_objective = 0.0
            else:
                optimal_ratio, optimal_switch = best_params
                max_objective = -best_value

        else:
            bounds = [(0.1, 5.0), (4.0, 20.0)]
            initial_points = [
                [target_ratio, target_switch],
                [0.5, 12.0],
                [1.0, 16.0],
                [1.5, 18.0],
            ]

            best_result = None
            best_fun = float('inf')

            for x0 in initial_points:
                try:
                    result = minimize(self.objective_function, x0=x0,
                                      args=(model_params, interaction_type, w),
                                      bounds=bounds,
                                      method='L-BFGS-B',
                                      options={'maxiter': 1000, 'ftol': 1e-10})
                except Exception:
                    # Best effort: a failing start point is skipped; others may succeed.
                    continue

                if result.fun < best_fun and result.success:
                    best_fun = result.fun
                    best_result = result

            if best_result is not None:
                optimal_ratio, optimal_switch = best_result.x
                max_objective = -best_result.fun
            else:
                optimal_ratio, optimal_switch = target_ratio, target_switch
                max_objective = 0.0

        self.optimal_params = {
            'optimal_ratio': optimal_ratio,
            'optimal_switch_time': optimal_switch,
            'max_objective': max(1e-9, max_objective),
        }

        return self.optimal_params

    def calculate_fixed_condition_yield(self, model_params: Dict, interaction_type: str,
                                        w: float = 0.5) -> Dict:
        """Compute the yields of the fixed 4:6 ratio / 16 h switch condition.

        Unlike objective_function, no preference bonus is applied and the yield
        coefficients are not floored.

        Returns:
            dict with the fixed 'optimal_ratio' / 'optimal_switch_time', 'max_objective'
            (the weighted yield), 'cellulose_yield', 'pigment_yield' and
            'fixed_condition_applied'. On any error, placeholder yields of 0.001 are
            returned after printing the error.
        """
        fixed_ratio = self._TARGET_RATIO
        fixed_switch_time = self._TARGET_SWITCH_TIME
        total_time = model_params.get('total_time', 24)

        try:
            t_25, N_A_25, _, t_37, _, N_B_37 = self._simulate_two_phase(
                interaction_type, model_params, fixed_ratio, fixed_switch_time, total_time, 200)

            cellulose_yield = model_params.get('Y_cellulose', 0.06) * self._trapezoid(N_A_25, t_25)
            pigment_yield = model_params.get('Y_pigment', 0.10) * self._trapezoid(N_B_37, t_37)

            objective_value = w * cellulose_yield + (1 - w) * pigment_yield

            return {
                'optimal_ratio': fixed_ratio,
                'optimal_switch_time': fixed_switch_time,
                'max_objective': objective_value,
                'cellulose_yield': cellulose_yield,
                'pigment_yield': pigment_yield,
                'fixed_condition_applied': True
            }

        except Exception as e:
            print(f"Error in fixed condition calculation: {e}")
            return {
                'optimal_ratio': fixed_ratio,
                'optimal_switch_time': fixed_switch_time,
                'max_objective': 0.001,
                'cellulose_yield': 0.001,
                'pigment_yield': 0.001,
                'fixed_condition_applied': True
            }

    def plot_optimization_results(self, model_params: Dict, interaction_type: str, w: float = 0.5):
        """Plot objective slices and the biomass trajectory at the fixed 4:6 / 16 h condition.

        Requires optimize_parameters() to have run first. It also reads
        ``model_params['total_time']``. The panels are centred on the fixed condition,
        not on ``self.optimal_params``.
        """
        if self.optimal_params is None:
            print("Please perform parameter optimization first")
            return

        plt.figure(figsize=(15, 5))

        fixed_result = self.calculate_fixed_condition_yield(model_params, interaction_type, w)
        fixed_ratio = fixed_result['optimal_ratio']
        fixed_switch = fixed_result['optimal_switch_time']

        plt.subplot(1, 3, 1)
        ratios = np.linspace(0.1, 10, 50)
        objectives = [-self.objective_function((r, fixed_switch), model_params, interaction_type, w)
                      for r in ratios]
        plt.plot(ratios, objectives, 'b-', linewidth=2)
        plt.axvline(fixed_ratio, color='red', linestyle='--',
                    label=f'Fixed ratio: {fixed_result["optimal_ratio"]:.3f} (4:6)')
        plt.xlabel('Inoculation ratio (A:B)')
        plt.ylabel('Objective function value')
        plt.title('Objective Function vs Inoculation Ratio\n(Fixed 4:6 Ratio)')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.subplot(1, 3, 2)
        switch_times = np.linspace(2, model_params['total_time'] - 2, 50)
        objectives = [-self.objective_function((fixed_ratio, t), model_params, interaction_type, w)
                      for t in switch_times]
        plt.plot(switch_times, objectives, 'g-', linewidth=2)
        plt.axvline(fixed_switch, color='red', linestyle='--',
                    label=f'Fixed switch time: {fixed_result["optimal_switch_time"]:.1f} hours')
        plt.xlabel('Temperature switch time (hours)')
        plt.ylabel('Objective function value')
        plt.title('Objective Function vs Switch Time\n(Fixed 16h)')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.subplot(1, 3, 3)
        total_time = model_params['total_time']
        # Continuous trajectory: the 37 °C phase continues from the state at the switch
        # rather than restarting from the inoculum.
        t_25, N_A_25, N_B_25, t_37, N_A_37, N_B_37 = self._simulate_two_phase(
            interaction_type, model_params, fixed_ratio, fixed_switch, total_time, 50)

        plt.plot(t_25, N_A_25, 'b-', label='Strain A (25°C)', linewidth=2)
        plt.plot(t_25, N_B_25, 'r-', label='Strain B (25°C)', linewidth=2)
        plt.plot(t_37, N_A_37, 'b--', label='Strain A (37°C)', linewidth=2)
        plt.plot(t_37, N_B_37, 'r--', label='Strain B (37°C)', linewidth=2)
        plt.axvline(fixed_switch, color='black', linestyle=':',
                    label=f'Temperature switch: {fixed_result["optimal_switch_time"]}h')
        plt.xlabel('Time (hours)')
        plt.ylabel('Biomass')
        plt.title(f'Biomass Dynamics (4:6 Ratio, 16h Switch)\nObjective: {fixed_result["max_objective"]:.4f}')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    def auto_optimize(self, solo_A_data: Dict, solo_B_data: Dict, co_culture_data: Dict,
                      product_data: Dict = None, w: float = 0.5) -> Dict:
        """Run the full pipeline: estimate parameters, optimize, evaluate, print and plot.

        Returns:
            dict with 'model_parameters', 'optimization_results' and 'interaction_type'.
            'optimization_results' holds the optimizer's ratio/switch/objective, but its
            'cellulose_yield' / 'pigment_yield' come from the fixed 4:6 / 16 h condition.
        """
        print("=== Starting Automated Parameter Optimization ===")
        print("Target: 4:6 inoculation ratio, 16h temperature switch")

        model_params = self.estimate_parameters_for_optimization(solo_A_data, solo_B_data,
                                                                 co_culture_data, product_data)

        print("\nAdjusted Model Parameters:")
        for key, value in model_params.items():
            print(f"{key}: {value:.4f}")

        print("\nPerforming parameter optimization...")
        optimal_result = self.optimize_parameters(model_params, self.interaction_type, w)

        fixed_result = self.calculate_fixed_condition_yield(model_params, self.interaction_type, w)

        print("\n=== Optimization Results ===")
        print(f"Optimal ratio: {optimal_result['optimal_ratio']:.3f} (A:B ≈ {optimal_result['optimal_ratio']:.2f}:1)")
        print(f"Optimal switch time: {optimal_result['optimal_switch_time']:.1f} hours")
        print(f"Maximum objective value: {optimal_result['max_objective']:.6f}")

        print("\n=== Fixed Condition Results (4:6, 16h) ===")
        print(f"Cellulose yield: {fixed_result['cellulose_yield']:.6f} g/L")
        print(f"Pigment yield: {fixed_result['pigment_yield']:.6f} g/L")
        print(f"Total objective value: {fixed_result['max_objective']:.6f}")

        final_result = optimal_result.copy()
        final_result['cellulose_yield'] = fixed_result['cellulose_yield']
        final_result['pigment_yield'] = fixed_result['pigment_yield']
        final_result['fixed_condition_applied'] = False

        self.plot_optimization_results(model_params, self.interaction_type, w)

        return {
            'model_parameters': model_params,
            'optimization_results': final_result,
            'interaction_type': self.interaction_type
        }

Interaction Analysis¶

In [ ]:
import numpy as np
import pandas as pd
from scipy import stats
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple

def add_noise(data, noise_level=0.01):
    """Return a shallow copy of ``data`` whose OD curve has Gaussian noise added.

    The series is read from ``'OD'`` or, failing that, from ``'OD_total'``. One draw of
    ``np.random.normal(0, noise_level, n)`` from NumPy's global random state is added,
    and the result is floored at 0.001. When the series came from ``'OD_total'``, the
    noisy array is stored under both ``'OD_total'`` and ``'OD'``. ``data`` itself is
    left unmodified.

    Raises:
        KeyError: if ``data`` contains neither ``'OD'`` nor ``'OD_total'``.
    """
    result = data.copy()

    if 'OD' in data:
        source_key = 'OD'
    elif 'OD_total' in data:
        source_key = 'OD_total'
    else:
        raise KeyError("Data dictionary must contain either 'OD' or 'OD_total' key")

    jitter = np.random.normal(0, noise_level, len(data[source_key]))
    noisy = np.maximum(np.array(data[source_key]) + jitter, 0.001)

    result[source_key] = noisy
    if source_key == 'OD_total':
        result['OD'] = noisy

    return result

class InteractionAnalyzer:
    """Classify a two-strain interaction from growth curves and estimate model parameters.

    Growth data are dicts holding a ``'time'`` sequence and an ``'OD'`` (or ``'OD_total'``)
    sequence sampled at those times.
    """

    # Lower / upper bounds on (N0, mu, K) for the logistic fit in estimate_growth_parameters().
    _FIT_LOWER = (0.001, 0.01, 0.1)
    _FIT_UPPER = (1.0, 2.0, 10.0)

    def __init__(self):
        # Label from the last statistical_test() call.
        self.interaction_type = None
        # Full result dict from the last statistical_test() call.
        self.test_results = None
        # Parameter dict from the last estimate_all_parameters() call.
        self.estimated_params = None

    @staticmethod
    def _extract_od(data: Dict, default=None) -> np.ndarray:
        """Return ``data['OD']`` (preferred) or ``data['OD_total']`` as an array.

        If neither key is present, return ``default`` as an array. If no default is
        given, raise KeyError.
        """
        for key in ('OD', 'OD_total'):
            if key in data:
                return np.array(data[key])
        if default is None:
            raise KeyError("Data must contain either 'OD' or 'OD_total'")
        return np.array(default)

    @staticmethod
    def _trapezoid(y, x):
        """Trapezoidal integral that works on NumPy 1.x and 2.x.

        ``np.trapezoid`` was added in NumPy 2.0, and ``np.trapz`` is deprecated there.
        The installer only requires numpy>=1.21.5.
        """
        integrate = getattr(np, "trapezoid", None) or np.trapz
        return integrate(y, x)

    def load_growth_data(self, solo_A_data: Dict, solo_B_data: Dict, co_culture_data: Dict):
        """Store the three growth-curve dicts for plot_growth_curves()."""
        self.solo_A = solo_A_data
        self.solo_B = solo_B_data
        self.co_culture = co_culture_data

    def calculate_growth_parameters(self, data: Dict) -> Tuple[float, float]:
        """Return (K, AUC) of a growth curve.

        K is the mean OD of the first (up to) three samples whose time is >= 80% of the
        final time; it falls back to the last OD if that mean is not finite. AUC is the
        trapezoidal area under OD(time). Returns (0.0, 0.0) for empty data.

        Raises:
            KeyError: if the data has neither 'OD' nor 'OD_total'.
        """
        time = np.array(data['time'])
        OD = self._extract_od(data)

        if len(time) == 0 or len(OD) == 0:
            return 0.0, 0.0

        stable_indices = np.where(time >= time.max() * 0.8)[0]
        K = np.mean(OD[stable_indices[:3]]) if len(stable_indices) > 0 else OD[-1]
        if not np.isfinite(K):
            K = OD[-1]

        AUC = self._trapezoid(OD, time)

        return K, AUC

    def logistic_growth_model(self, t, N0, mu, K):
        """Logistic curve N(t) = K / (1 + (K/N0 - 1) * exp(-mu * t))."""
        return K / (1 + (K / N0 - 1) * np.exp(-mu * t))

    def estimate_growth_parameters(self, data: Dict) -> Dict:
        """Fit the logistic model to a growth curve and return {'N0', 'mu', 'K'}.

        The fit is bounded by _FIT_LOWER / _FIT_UPPER. If curve_fit fails, a crude estimate
        is returned: the first OD, the steepest log-OD slope, and the mean of the last
        three ODs.

        Raises:
            KeyError: if the data has neither 'OD' nor 'OD_total'.
        """
        time = np.array(data['time'])
        OD = self._extract_od(data)

        if len(time) == 0 or len(OD) == 0:
            return {'N0': 0.1, 'mu': 0.3, 'K': 1.0}

        N0_guess = OD[0] if OD[0] > 0 else 0.01
        K_guess = np.max(OD) * 1.1
        mu_guess = self._estimate_growth_rate_from_data(time, OD)

        # curve_fit rejects an initial guess outside the bounds ("x0 is infeasible").
        # Without clipping, e.g. max(OD) > ~9.1 or OD[0] > 1 always forced the fallback.
        p0 = np.clip([N0_guess, mu_guess, K_guess], self._FIT_LOWER, self._FIT_UPPER)

        try:
            popt, _ = curve_fit(self.logistic_growth_model, time, OD, p0=p0,
                                bounds=(self._FIT_LOWER, self._FIT_UPPER))
        except Exception:
            # Best-effort fallback (non-convergence, shape mismatch, ...).
            K = np.mean(OD[-3:]) if len(OD) >= 3 else OD[-1]
            return {'N0': OD[0], 'mu': mu_guess, 'K': K}

        N0, mu, K = popt
        return {
            'N0': N0 if np.isfinite(N0) else N0_guess,
            'mu': mu if np.isfinite(mu) else mu_guess,
            'K': K if np.isfinite(K) else K_guess,
        }

    def _estimate_growth_rate_from_data(self, time: np.ndarray, OD: np.ndarray) -> float:
        """Return the steepest finite-difference slope of log(OD), floored at 0.01.

        Steps with zero time difference and non-finite slopes are skipped. Returns 0.1
        when OD is empty or no slope can be computed.
        """
        if len(OD) == 0:
            return 0.1

        log_OD = np.log(np.maximum(OD, 1e-10))
        derivatives = []

        for i in range(1, len(log_OD)):
            if time[i] != time[i - 1]:
                derivative = (log_OD[i] - log_OD[i - 1]) / (time[i] - time[i - 1])
                if np.isfinite(derivative):
                    derivatives.append(derivative)

        return max(0.01, max(derivatives)) if derivatives else 0.1

    def estimate_competition_coefficients(self, solo_A_data: Dict, solo_B_data: Dict,
                                          co_culture_data: Dict) -> Tuple[float, float]:
        """Estimate Lotka-Volterra coefficients (alpha_BA, alpha_AB) from plateau ODs.

        The co-culture plateau is split between the strains in proportion to their solo
        plateaus. Each alpha is the drop in a strain's plateau per unit of the other
        strain's co-culture plateau, clipped to [0.01, 3.0]. The value 0.5 is used before
        clipping when a denominator is not positive.
        """
        K_A_solo, _ = self.calculate_growth_parameters(solo_A_data)
        K_B_solo, _ = self.calculate_growth_parameters(solo_B_data)
        K_co_total, _ = self.calculate_growth_parameters(co_culture_data)

        total_K = K_A_solo + K_B_solo
        if total_K <= 0:
            proportion_A = proportion_B = 0.5
        else:
            proportion_A = K_A_solo / total_K
            proportion_B = K_B_solo / total_K

        K_A_co = K_co_total * proportion_A
        K_B_co = K_co_total * proportion_B

        if K_B_co > 0 and np.isfinite(K_B_co):
            alpha_BA = (K_A_solo - K_A_co) / K_B_co
        else:
            alpha_BA = 0.5

        if K_A_co > 0 and np.isfinite(K_A_co):
            alpha_AB = (K_B_solo - K_B_co) / K_A_co
        else:
            alpha_AB = 0.5

        alpha_BA = max(0.01, min(alpha_BA, 3.0))
        alpha_AB = max(0.01, min(alpha_AB, 3.0))

        return alpha_BA, alpha_AB

    def _summarize_replicates(self, replicates: List[Dict]) -> Tuple[List[float], List[float]]:
        """Return (finite K values, finite AUC values) over ``replicates``."""
        pairs = [self.calculate_growth_parameters(rep) for rep in replicates]
        K_values = [K for K, _ in pairs if np.isfinite(K)]
        AUC_values = [auc for _, auc in pairs if np.isfinite(auc)]
        return K_values, AUC_values

    @staticmethod
    def _one_sample_p_value(samples: List[float], popmean: float) -> float:
        """p-value of scipy's one-sample t-test; 1.0 if the test raises or yields NaN."""
        try:
            _, p_value = stats.ttest_1samp(samples, popmean)
        except Exception:
            return 1.0
        return 1.0 if np.isnan(p_value) else p_value

    def statistical_test(self, solo_A_replicates: List[Dict], solo_B_replicates: List[Dict],
                         co_culture_replicates: List[Dict], alpha: float = 0.05) -> Dict:
        """Classify the interaction by comparing co-culture K/AUC with the additive prediction.

        The prediction is mean(solo A) + mean(solo B) for both K and AUC. When either
        one-sample t-test is significant at ``alpha``, the result is "Competition" if both
        observed means are below the prediction, "Mutualism" if both are above, and
        "Complex Interaction" otherwise. In all other cases it is "Neutral".

        Returns:
            The result dict, also stored in ``self.test_results``. The label is stored
            in ``self.interaction_type``.
        """
        solo_A_K, solo_A_AUC = self._summarize_replicates(solo_A_replicates)
        solo_B_K, solo_B_AUC = self._summarize_replicates(solo_B_replicates)
        co_culture_K, co_culture_AUC = self._summarize_replicates(co_culture_replicates)

        if len(solo_A_K) == 0 or len(solo_B_K) == 0 or len(co_culture_K) == 0:
            interaction_type = "Neutral"
            p_value_K = 1.0
            p_value_AUC = 1.0
            K_observed = 0.0
            K_predicted = 0.0
            AUC_observed = 0.0
            AUC_predicted = 0.0
        else:
            K_predicted = np.mean(solo_A_K) + np.mean(solo_B_K)
            AUC_predicted = np.mean(solo_A_AUC) + np.mean(solo_B_AUC)

            p_value_K = self._one_sample_p_value(co_culture_K, K_predicted)
            p_value_AUC = self._one_sample_p_value(co_culture_AUC, AUC_predicted)

            K_observed = np.mean(co_culture_K)
            AUC_observed = np.mean(co_culture_AUC)

            if p_value_K < alpha or p_value_AUC < alpha:
                if K_observed < K_predicted and AUC_observed < AUC_predicted:
                    interaction_type = "Competition"
                elif K_observed > K_predicted and AUC_observed > AUC_predicted:
                    interaction_type = "Mutualism"
                else:
                    interaction_type = "Complex Interaction"
            else:
                interaction_type = "Neutral"

        self.interaction_type = interaction_type
        self.test_results = {
            'interaction_type': interaction_type,
            'p_value_K': p_value_K,
            'p_value_AUC': p_value_AUC,
            'K_observed': K_observed,
            'K_predicted': K_predicted,
            'AUC_observed': AUC_observed,
            'AUC_predicted': AUC_predicted
        }

        return self.test_results

    def estimate_all_parameters(self, solo_A_data: Dict, solo_B_data: Dict,
                                co_culture_data: Dict, product_data: Dict = None) -> Dict:
        """Assemble the full model parameter dict from solo and co-culture data.

        The 37 °C growth rates are fixed multiples of the fitted 25 °C rates (0.05x for A,
        1.5x for B). ``total_time`` is the latest final time across the three datasets.
        Only the 'Y_cellulose' / 'Y_pigment' keys of ``product_data`` override defaults.

        Returns:
            The parameter dict, also stored in ``self.estimated_params``.
        """
        try:
            params_A = self.estimate_growth_parameters(solo_A_data)
            params_B = self.estimate_growth_parameters(solo_B_data)
        except Exception:
            params_A = self._fallback_parameter_estimation(solo_A_data)
            params_B = self._fallback_parameter_estimation(solo_B_data)

        alpha_BA, alpha_AB = self.estimate_competition_coefficients(solo_A_data, solo_B_data, co_culture_data)

        alpha_BA = max(0.01, min(alpha_BA, 3.0))
        alpha_AB = max(0.01, min(alpha_AB, 3.0))

        co_od_data = self._extract_od(co_culture_data, default=[0.1])

        solo_A_time = solo_A_data.get('time', [0, 24])
        solo_B_time = solo_B_data.get('time', [0, 24])
        co_culture_time = co_culture_data.get('time', [0, 24])

        all_params = {
            'N0_total': co_od_data[0] if len(co_od_data) > 0 else 0.1,
            'mu_A_25': params_A.get('mu', 0.3),
            'mu_B_25': params_B.get('mu', 0.4),
            'mu_A_37': params_A.get('mu', 0.3) * 0.05,
            'mu_B_37': params_B.get('mu', 0.4) * 1.5,
            'K_A': params_A.get('K', 1.0),
            'K_B': params_B.get('K', 1.2),
            'alpha_BA': alpha_BA,
            'alpha_AB': alpha_AB,
            'gamma_A': 0.15,
            'gamma_B': 0.15,
            'Y_cellulose': 0.1,
            'Y_pigment': 0.15,
            'total_time': max(
                solo_A_time[-1] if len(solo_A_time) > 0 else 24.0,
                solo_B_time[-1] if len(solo_B_time) > 0 else 24.0,
                co_culture_time[-1] if len(co_culture_time) > 0 else 24.0
            )
        }

        if product_data:
            all_params.update({k: v for k, v in product_data.items()
                               if k in ['Y_cellulose', 'Y_pigment']})

        self.estimated_params = all_params
        return all_params

    def _fallback_parameter_estimation(self, data: Dict) -> Dict:
        """Crude {'N0', 'mu', 'K'} estimate used when the logistic fit path raises.

        mu is the slope of log(OD) against time over the samples between 1.1x the initial
        OD and 0.9x the maximum OD (floored at 0.1). It defaults to 0.3 when fewer than
        three such samples exist or the regression fails.
        """
        time = np.array(data['time'])
        OD = self._extract_od(data, default=[0.1, 0.5, 1.0])

        if len(OD) == 0:
            return {'N0': 0.1, 'mu': 0.3, 'K': 1.0}

        log_OD = np.log(np.maximum(OD, 1e-10))
        mask = (OD > OD[0] * 1.1) & (OD < np.max(OD) * 0.9)

        mu = 0.3
        if np.sum(mask) > 2:
            try:
                slope, _ = np.polyfit(time[mask], log_OD[mask], 1)
                mu = max(0.1, slope)
            except Exception:
                mu = 0.3

        return {'N0': OD[0], 'mu': mu, 'K': np.max(OD)}

    def plot_growth_curves(self):
        """Plot solo A, solo B and co-culture OD curves stored by load_growth_data()."""
        plt.figure(figsize=(10, 6))
        # Each dataset may use either 'OD' or 'OD_total'.
        plt.plot(self.solo_A['time'], self._extract_od(self.solo_A), 'b-',
                 label='Strain A solo culture', linewidth=2)
        plt.plot(self.solo_B['time'], self._extract_od(self.solo_B), 'r-',
                 label='Strain B solo culture', linewidth=2)
        plt.plot(self.co_culture['time'], self._extract_od(self.co_culture), 'g-',
                 label='Co-culture', linewidth=2)

        plt.xlabel('Time (hours)', fontsize=12)
        plt.ylabel('OD value', fontsize=12)
        plt.title('Growth Curve Comparison', fontsize=14)
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

Installer¶

In [ ]:
import os
import sys
import shutil
import subprocess
import platform
from pathlib import Path
from datetime import datetime

class AppInstaller:
    """Console installer for the Microbial Co-culture Analysis System.

    Installs the pip requirements into the interpreter running this script,
    creates a desktop shortcut and a start script that launch ``main.py``
    with that same interpreter, and writes ``README.txt`` into the
    installation directory.
    """

    # pip index used by default (Tsinghua University PyPI mirror).
    DEFAULT_INDEX_URL = "https://pypi.tuna.tsinghua.edu.cn/simple"

    def __init__(self, install_dir=None, index_url=DEFAULT_INDEX_URL):
        """Create an installer.

        Args:
            install_dir: Directory holding ``main.py`` and the analysis
                modules. Defaults to this file's directory, or to the current
                working directory when ``__file__`` is undefined (e.g. when
                the code is run from a notebook cell).
            index_url: pip package index to install from; its host is also
                passed to ``--trusted-host``. ``None`` uses pip's configured
                index.
        """
        self.app_name = "Microbial Co-culture Analysis System"
        self.version = "1.0"
        self.requirements = [
            "PyQt5>=5.15.7",
            "matplotlib>=3.5.1",
            "pandas>=1.4.2",
            "numpy>=1.21.5",
            "scipy>=1.7.3",
            "pillow>=9.0.0"
        ]
        self.index_url = index_url
        if install_dir is None:
            module_file = globals().get("__file__")
            install_dir = os.path.dirname(os.path.abspath(module_file)) if module_file else os.getcwd()
        self.current_dir = os.path.abspath(install_dir)

    def check_python_version(self):
        """Return True if the running interpreter is Python 3.7 or newer."""
        print("Checking Python version...")
        version = sys.version_info
        if version.major < 3 or (version.major == 3 and version.minor < 7):
            print(f"✗ Python version too low: {sys.version}")
            print("Please install Python 3.7 or higher")
            return False
        print(f"✓ Python version meets requirements: {sys.version}")
        return True

    def _pip_index_args(self):
        """Return the pip ``-i`` / ``--trusted-host`` arguments for ``self.index_url``."""
        if not self.index_url:
            return []
        from urllib.parse import urlparse
        args = ["-i", self.index_url]
        host = urlparse(self.index_url).hostname
        if host:
            args += ["--trusted-host", host]
        return args

    def install_dependencies(self):
        """pip-install every requirement into ``sys.executable``.

        Returns:
            True if every package installed, False if any failed or timed out.
        """
        print("\nInstalling dependencies...")
        print("=" * 50)

        success_count = 0
        failed_packages = []
        index_args = self._pip_index_args()

        for package in self.requirements:
            package_name = package.split('>=')[0] if '>=' in package else package
            print(f"Installing {package_name}...")

            try:
                result = subprocess.run(
                    [sys.executable, "-m", "pip", "install", *index_args, package],
                    capture_output=True, text=True, timeout=300)

                if result.returncode == 0:
                    print(f"✓ {package_name} installed successfully")
                    success_count += 1
                else:
                    print(f"✗ {package_name} installation failed")
                    print(f"Error message: {result.stderr}")
                    failed_packages.append(package_name)

            except subprocess.TimeoutExpired:
                print(f"✗ {package_name} installation timeout")
                failed_packages.append(package_name)
            except Exception as e:
                print(f"✗ {package_name} installation exception: {e}")
                failed_packages.append(package_name)

        print("=" * 50)
        print(f"Dependency installation completed: {success_count}/{len(self.requirements)} successful")

        if failed_packages:
            print(f"Failed packages: {', '.join(failed_packages)}")
            print("\nPlease try the following solutions:")
            print("1. Check network connection")
            print("2. Manual installation: pip install package_name")
            print("3. Use other mirror sources")
            return False

        return True

    def create_shortcut(self):
        """Create a platform-appropriate desktop launcher; return True on success."""
        print("\nCreating desktop shortcut...")
        system = platform.system()
        desktop_path = self.get_desktop_path()

        if system == "Windows":
            return self.create_windows_shortcut(desktop_path)
        elif system == "Linux":
            return self.create_linux_shortcut(desktop_path)
        elif system == "Darwin":
            return self.create_macos_shortcut(desktop_path)
        else:
            print(f"⚠ Unsupported operating system: {system}")
            return False

    def get_desktop_path(self):
        """Return the user's desktop directory as a string.

        On Windows the shell folder API is queried first. Elsewhere
        ``~/Desktop`` is used, or the home directory itself when no Desktop
        folder exists (common on minimal/headless Linux installs).
        """
        home = Path.home()
        if platform.system() == "Windows":
            try:
                import ctypes.wintypes
                CSIDL_DESKTOP = 0
                buf = ctypes.create_unicode_buffer(ctypes.wintypes.MAX_PATH)
                ctypes.windll.shell32.SHGetFolderPathW(None, CSIDL_DESKTOP, None, 0, buf)
                if buf.value:
                    return buf.value
            except (ImportError, AttributeError, OSError, ValueError):
                pass
            return str(home / "Desktop")

        desktop = home / "Desktop"
        return str(desktop if desktop.exists() else home)

    def create_windows_shortcut(self, desktop_path):
        """Create a ``.lnk`` shortcut via a temporary VBScript run with cscript."""
        vbs_file = os.path.join(self.current_dir, "create_shortcut.vbs")
        try:
            vbs_content = f"""
Set WshShell = CreateObject("WScript.Shell")
Set oShellLink = WshShell.CreateShortcut("{desktop_path}\\{self.app_name}.lnk")
oShellLink.TargetPath = "{sys.executable}"
oShellLink.Arguments = "{os.path.join(self.current_dir, 'main.py')}"
oShellLink.WorkingDirectory = "{self.current_dir}"
oShellLink.Description = "{self.app_name}"
oShellLink.Save
"""
            # cscript reads non-Unicode scripts in the system ANSI code page,
            # which Python exposes as "mbcs" (hard-coding gbk only worked on
            # Chinese-locale systems).
            with open(vbs_file, "w", encoding="mbcs") as f:
                f.write(vbs_content)

            subprocess.run(["cscript", "//Nologo", vbs_file], check=True)

            print(f"✓ Windows shortcut created: {desktop_path}\\{self.app_name}.lnk")
            return True

        except Exception as e:
            print(f"✗ Failed to create Windows shortcut: {e}")
            return False
        finally:
            # Remove the helper script even when cscript fails.
            try:
                os.remove(vbs_file)
            except OSError:
                pass

    def create_linux_shortcut(self, desktop_path):
        """Write an executable ``.desktop`` launcher into ``desktop_path``."""
        try:
            desktop_file = os.path.join(desktop_path, f"{self.app_name}.desktop")
            # Launch with the interpreter the dependencies were installed into.
            desktop_content = f"""[Desktop Entry]
Version=1.0
Type=Application
Name={self.app_name}
Comment=Microbial Co-culture Interaction Analysis System
Exec="{sys.executable}" "{os.path.join(self.current_dir, 'main.py')}"
Path={self.current_dir}
Terminal=false
Categories=Science;
StartupNotify=true
"""

            with open(desktop_file, "w", encoding="utf-8") as f:
                f.write(desktop_content)

            os.chmod(desktop_file, 0o755)
            print(f"✓ Linux shortcut created: {desktop_file}")
            return True

        except Exception as e:
            print(f"✗ Failed to create Linux shortcut: {e}")
            return False

    def create_macos_shortcut(self, desktop_path):
        """Write an executable ``.command`` launcher script into ``desktop_path``."""
        try:
            script_path = os.path.join(desktop_path, f"Start {self.app_name}.command")
            script_content = f"""#!/bin/bash
cd "{self.current_dir}"
"{sys.executable}" "{os.path.join(self.current_dir, 'main.py')}"
"""

            with open(script_path, "w", encoding="utf-8") as f:
                f.write(script_content)

            os.chmod(script_path, 0o755)
            print(f"✓ macOS startup script created: {script_path}")
            return True

        except Exception as e:
            print(f"✗ Failed to create macOS shortcut: {e}")
            return False

    def create_start_script(self):
        """Write a start script (.bat on Windows, .sh elsewhere) into the install dir."""
        print("\nCreating startup script...")
        try:
            if sys.platform == "win32":
                bat_content = f"""@echo off
chcp 65001 >nul
echo Starting {self.app_name}...
cd /d "{self.current_dir}"
"{sys.executable}" "main.py"
pause
"""
                bat_path = os.path.join(self.current_dir, "Start Application.bat")
                with open(bat_path, "w", encoding="utf-8") as f:
                    f.write(bat_content)
                print("✓ Windows startup script created: Start Application.bat")
            else:
                sh_content = f"""#!/bin/bash
echo "Starting {self.app_name}..."
cd "{self.current_dir}"
"{sys.executable}" "main.py"
"""
                sh_path = os.path.join(self.current_dir, "Start Application.sh")
                with open(sh_path, "w", encoding="utf-8") as f:
                    f.write(sh_content)
                os.chmod(sh_path, 0o755)
                print("✓ Linux/macOS startup script created: Start Application.sh")

            return True

        except Exception as e:
            print(f"✗ Failed to create startup script: {e}")
            return False

    def create_readme(self):
        """Write ``README.txt`` with usage instructions into the install dir."""
        print("\nCreating documentation...")
        try:
            readme_lines = [
                f"# {self.app_name} v{self.version}",
                "",
                "## System Introduction",
                "Microbial Co-culture Interaction Analysis System is a professional desktop application for analyzing interaction types between two microorganisms in co-culture conditions, calculating competitiveness indices, and optimizing production parameters.",
                "",
                "## System Requirements",
                "- Python 3.7 or higher",
                "- Windows 7+/macOS 10.12+/Ubuntu 16.04+",
                "- At least 4GB RAM",
                "- At least 500MB available storage space",
                "",
                "## Installation Complete",
                "",
                f"Congratulations! {self.app_name} has been successfully installed on your system.",
                "",
                "## Startup Methods",
                "",
                "### Method 1: Use Desktop Shortcut (Recommended)",
                f'- Find the "{self.app_name}" shortcut on the desktop and double-click',
                "",
                "### Method 2: Use Startup Script",
                '- Windows: Double-click "Start Application.bat" file',
                '- Linux/macOS: Double-click "Start Application.sh" file',
                "",
                "### Method 3: Command Line Startup",
                "Open terminal/command prompt, navigate to installation directory, then run:",
                "```",
                "python main.py",
                "```",
                "",
                "## User Guide",
                "",
                "### 1. Data Preparation",
                "Prepare three CSV format growth curve data files:",
                "- Strain A solo culture data (contains time and OD columns)",
                "- Strain B solo culture data (contains time and OD columns)",
                "- Co-culture data (contains time and OD columns)",
                "",
                "### 2. Application Operation Steps",
                "1. Start the application",
                "2. Upload three data files in the \"Data Upload\" tab",
                "3. Click the \"Start Analysis\" button",
                "4. View interaction type and competitiveness index in the \"Analysis Results\" tab",
                "5. Get production suggestions in the \"Parameter Optimization\" tab",
                "",
                "### 3. Result Interpretation",
                "- **Competition**: Strains inhibit each other's growth",
                "- **Mutualism**: Strains promote each other's growth",
                "- **Neutral**: Strains do not affect each other",
                "",
                "## Technical Support",
                "If encountering problems, please check:",
                "1. Whether all dependency packages are correctly installed",
                "2. Whether data file format is correct",
                "3. Whether system meets requirements",
                "",
                f"Installation directory: {self.current_dir}",
                f"Installation time: {self.get_current_time()}",
                ""
            ]

            readme_content = "\n".join(readme_lines)
            readme_path = os.path.join(self.current_dir, "README.txt")
            with open(readme_path, "w", encoding="utf-8") as f:
                f.write(readme_content)

            print("✓ Documentation created: README.txt")
            return True

        except Exception as e:
            print(f"✗ Failed to create documentation: {e}")
            return False

    def get_current_time(self):
        """Return the local time as ``YYYY-MM-DD HH:MM:SS``."""
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def verify_installation(self):
        """Return True if ``main.py`` and every required module file exist."""
        print("\nVerifying installation...")
        main_script = os.path.join(self.current_dir, "main.py")
        if not os.path.exists(main_script):
            print("✗ Main program file main.py does not exist")
            return False

        required_modules = [
            "interaction_analysis.py",
            "competitiveness_calculation.py",
            "parameter_optimization.py"
        ]

        missing_modules = [module for module in required_modules
                           if not os.path.exists(os.path.join(self.current_dir, module))]

        if missing_modules:
            print(f"✗ Missing necessary module files: {', '.join(missing_modules)}")
            return False

        print("✓ All necessary files exist")
        return True

    def install(self):
        """Run the whole installation and return True unless a hard check fails.

        Dependency installation failures are reported but do not abort the
        install, so the user can install the missing packages manually.
        """
        print(f"Starting installation of {self.app_name} v{self.version}")
        print("=" * 60)

        if not self.check_python_version():
            return False

        if not self.verify_installation():
            print("Please ensure all necessary files are in the current directory")
            return False

        if not self.install_dependencies():
            print("Dependency installation failed, please manually install missing packages")

        self.create_shortcut()
        self.create_start_script()
        self.create_readme()

        print("=" * 60)
        print("🎉 Installation completed!")
        print("\nYou can start the application by:")
        print("1. Desktop shortcut (recommended)")
        print("2. Startup script file")
        print("3. Command line: python main.py")
        print("\nPlease check README.txt file for detailed usage instructions")

        if sys.platform == "win32":
            input("\nPress Enter to exit...")

        return True

def main():
    """Run the installer and terminate the process with its exit status.

    Exits with 0 when installation succeeds and 1 on failure, on Ctrl+C, or
    on any unexpected error (which is printed rather than re-raised).
    """
    exit_code = 1
    try:
        exit_code = 0 if AppInstaller().install() else 1
    except KeyboardInterrupt:
        print("\n\nInstallation interrupted by user")
    except Exception as e:
        print(f"\n\nError during installation: {e}")
    sys.exit(exit_code)


if __name__ == "__main__":
    main()

Competitiveness¶

In [ ]:
import numpy as np
from typing import Dict, Tuple
from interaction_analysis import InteractionAnalyzer

def add_noise(data, noise_level=0.01, rng=None, min_value=0.001):
    """Return a copy of a growth-curve dict with Gaussian noise added to its OD.

    Used to synthesise pseudo-replicates for the statistical tests.

    Args:
        data: Mapping containing ``'OD'`` or ``'OD_total'`` (a sequence of
            readings). Other keys are shallow-copied unchanged and ``data``
            itself is not modified.
        noise_level: Standard deviation of the zero-mean Gaussian noise.
        rng: Optional ``numpy.random.Generator`` for reproducible noise. When
            ``None``, the global ``np.random`` state is used (original
            behaviour).
        min_value: Floor applied to the noisy readings so OD stays positive.

    Returns:
        A new dict with the noisy readings under the same key. For
        ``'OD_total'`` input, ``'OD'`` is also set to the same noisy array,
        so consumers that read ``'OD'`` work with co-culture data too.

    Raises:
        KeyError: if neither ``'OD'`` nor ``'OD_total'`` is present.
    """
    if 'OD' in data:
        key = 'OD'
    elif 'OD_total' in data:
        key = 'OD_total'
    else:
        raise KeyError("Data dictionary must contain either 'OD' or 'OD_total' key")

    values = np.asarray(data[key], dtype=float)
    sampler = np.random if rng is None else rng
    noise = sampler.normal(0, noise_level, len(values))
    noisy = np.maximum(values + noise, min_value)

    noisy_data = data.copy()
    noisy_data[key] = noisy
    if key == 'OD_total':
        noisy_data['OD'] = noisy

    return noisy_data

class CompetitivenessCalculator(InteractionAnalyzer):
    """Compute an interaction-specific competitiveness index for two strains.

    Which index is produced depends on ``self.interaction_type``:

    * ``"Competition"``: ``alpha_BA / alpha_AB`` (competition-coefficient ratio).
    * ``"Neutral"``: ``mu_A_25 / mu_B_25`` (growth-rate ratio).
    * ``"Mutualism"``: (co-culture K + AUC) / (sum of solo K + AUC).
    """

    def __init__(self):
        super().__init__()
        # Last index returned by calculate_competitiveness (None until computed).
        self.competitiveness_index = None

    @staticmethod
    def _safe_ratio(numerator, denominator):
        """Return ``numerator / denominator``.

        When the denominator is zero or non-finite, return ``inf`` for a
        positive numerator and ``1.0`` otherwise.
        """
        if denominator != 0 and np.isfinite(denominator):
            return numerator / denominator
        return float('inf') if numerator > 0 else 1.0

    def estimate_parameters_from_data(self, solo_A_data: Dict, solo_B_data: Dict,
                                      co_culture_data: Dict, product_data: Dict = None) -> Dict:
        """Estimate model parameters, strengthening competition if K is suppressed.

        Delegates to ``estimate_all_parameters``. For a Competition
        interaction whose observed co-culture K falls below 80% of the
        predicted K, both competition coefficients are scaled by 1.2, capped
        at 1.5.

        Returns:
            The parameter dict from ``estimate_all_parameters`` (possibly adjusted).
        """
        basic_params = self.estimate_all_parameters(solo_A_data, solo_B_data, co_culture_data, product_data)

        if self.test_results is not None and self.interaction_type == "Competition":
            K_observed = self.test_results.get('K_observed')
            K_predicted = self.test_results.get('K_predicted')
            # Guard against a zero/missing prediction, which previously raised
            # ZeroDivisionError (or produced inf/nan for numpy scalars).
            if K_observed is not None and K_predicted and np.isfinite(K_predicted):
                K_ratio = K_observed / K_predicted
                if K_ratio < 0.8:
                    for key in ('alpha_BA', 'alpha_AB'):
                        if key in basic_params:
                            basic_params[key] = min(1.5, basic_params[key] * 1.2)

        return basic_params

    def calculate_competitiveness(self, solo_A_data: Dict, solo_B_data: Dict,
                                  co_culture_data: Dict) -> float:
        """Compute, store and return the competitiveness index.

        If the interaction type is unknown, it is determined first by running
        ``statistical_test`` on three noisy pseudo-replicates of each curve.
        If that test fails, the interaction is treated as Neutral.

        Returns:
            The index (see class docstring), or 1.0 for an unrecognised
            interaction type (in which case nothing is stored).
        """
        if self.interaction_type is None:
            try:
                replicates_A = [add_noise(solo_A_data, 0.01) for _ in range(3)]
                replicates_B = [add_noise(solo_B_data, 0.01) for _ in range(3)]
                replicates_co = [add_noise(co_culture_data, 0.01) for _ in range(3)]

                self.statistical_test(replicates_A, replicates_B, replicates_co)
            except Exception:
                # Best effort: without a test result, assume no interaction.
                self.interaction_type = "Neutral"

        params = self.estimate_parameters_from_data(solo_A_data, solo_B_data, co_culture_data)

        if self.interaction_type == "Competition":
            index = self._safe_ratio(params.get('alpha_BA', 0.5), params.get('alpha_AB', 0.5))
        elif self.interaction_type == "Neutral":
            index = self._safe_ratio(params.get('mu_A_25', 0.1), params.get('mu_B_25', 0.1))
        elif self.interaction_type == "Mutualism":
            K_A_solo, AUC_A_solo = self.calculate_growth_parameters(solo_A_data)
            K_B_solo, AUC_B_solo = self.calculate_growth_parameters(solo_B_data)
            K_co_total, AUC_co_total = self.calculate_growth_parameters(co_culture_data)

            solo_total = K_A_solo + K_B_solo + AUC_A_solo + AUC_B_solo
            co_total = K_co_total + AUC_co_total

            index = co_total / solo_total if solo_total > 0 else 1.0
        else:
            return 1.0

        self.competitiveness_index = index
        return index

    def _estimate_individual_biomass(self, co_culture_data: Dict) -> Tuple[float, float]:
        """Split the co-culture carrying capacity equally between the two strains.

        NOTE(review): the 50/50 split is a placeholder assumption, not derived
        from data.
        """
        total_K, _ = self.calculate_growth_parameters(co_culture_data)
        K_A_co = total_K * 0.5
        K_B_co = total_K * 0.5
        return K_A_co, K_B_co

    def _calculate_growth_rate(self, data: Dict) -> float:
        """Return the maximum finite slope of ``log(OD)`` between consecutive samples.

        Pairs with identical time stamps are skipped. If ``time`` and ``OD``
        differ in length, only their aligned prefix is used. Returns 0.0 when
        no slope can be computed.
        """
        time = np.asarray(data['time'], dtype=float)
        if 'OD' in data:
            OD = np.asarray(data['OD'])
        elif 'OD_total' in data:
            OD = np.asarray(data['OD_total'])
        else:
            OD = np.array([0.1, 0.5, 1.0])

        if len(time) == 0 or len(OD) == 0:
            return 0.0

        n = min(len(time), len(OD))
        log_OD = np.log(np.maximum(OD[:n], 1e-10))
        dt = np.diff(time[:n])
        dlog = np.diff(log_OD)

        valid = dt != 0
        with np.errstate(divide='ignore', invalid='ignore'):
            rates = dlog[valid] / dt[valid]
        rates = rates[np.isfinite(rates)]

        return np.max(rates) if rates.size > 0 else 0.0

    def print_competitiveness_analysis(self):
        """Print the interaction type, the index and a one-line interpretation."""
        if self.competitiveness_index is None:
            print("Please calculate competitiveness index first")
            return

        print("\n=== Competitiveness Analysis Results ===")
        print(f"Interaction type: {self.interaction_type}")
        print(f"Competitiveness index: {self.competitiveness_index:.4f}")

        if self.interaction_type == "Competition":
            if self.competitiveness_index > 1:
                print("Strain B is more competitive than Strain A")
            elif self.competitiveness_index < 1:
                print("Strain A is more competitive than Strain B")
            else:
                print("Strains are equally competitive")
        elif self.interaction_type == "Neutral":
            if self.competitiveness_index > 1:
                print("Strain A grows faster than Strain B")
            elif self.competitiveness_index < 1:
                print("Strain B grows faster than Strain A")
            else:
                # An index of exactly 1 previously (wrongly) reported B as faster.
                print("Strains grow equally fast")
        elif self.interaction_type == "Mutualism":
            if self.competitiveness_index > 1:
                print("Co-culture shows synergistic effect")
            else:
                print("Co-culture does not show synergistic effect")

Main¶

In [ ]:
import sys
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure

from PyQt5.QtWidgets import (QApplication, QMainWindow, QVBoxLayout, QHBoxLayout,
                             QPushButton, QLabel, QTextEdit, QFileDialog, QWidget,
                             QTabWidget, QGroupBox, QFormLayout, QLineEdit, QComboBox,
                             QProgressBar, QMessageBox, QTableWidget, QTableWidgetItem,
                             QHeaderView, QSplitter, QFrame, QScrollArea, QSizePolicy)
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtGui import QFont, QPalette, QColor

from interaction_analysis import InteractionAnalyzer, add_noise
from competitiveness_calculation import CompetitivenessCalculator
from parameter_optimization import ParameterOptimizer


class AnalysisThread(QThread):
    """Background worker that runs the full co-culture analysis pipeline.

    Signals:
        finished(dict): the results dict, or ``{'error': message}`` on failure.
        progress(int): completion percentage.
        message(str): human-readable status updates.
    """
    finished = pyqtSignal(dict)
    progress = pyqtSignal(int)
    message = pyqtSignal(str)

    def __init__(self, data_a, data_b, data_co):
        super().__init__()
        self.data_a = data_a
        self.data_b = data_b
        self.data_co = data_co
        # Reason the most recent validate_data() call failed (None if it passed).
        self.validation_error = None

    def run(self):
        """Validate the data, classify the interaction, then compute the index and parameters."""
        try:
            self.message.emit("Starting data analysis...")
            self.progress.emit(10)

            if not self.validate_data():
                error = 'Data format incorrect or missing required fields'
                # Include the specific problem so the user knows what to fix.
                if self.validation_error:
                    error = f"{error}: {self.validation_error}"
                self.finished.emit({'error': error})
                return

            analyzer = InteractionAnalyzer()
            analyzer.load_growth_data(self.data_a, self.data_b, self.data_co)

            self.message.emit("Calculating growth parameters...")
            self.progress.emit(30)

            replicates_A = [add_noise(self.data_a, 0.01) for _ in range(3)]
            replicates_B = [add_noise(self.data_b, 0.01) for _ in range(3)]
            replicates_co = [add_noise(self.data_co, 0.01) for _ in range(3)]

            self.message.emit("Performing statistical testing...")
            self.progress.emit(50)

            test_results = analyzer.statistical_test(replicates_A, replicates_B, replicates_co)

            if test_results is None:
                self.finished.emit({'error': 'Statistical testing failed, cannot determine interaction type'})
                return

            self.message.emit("Estimating competition coefficients...")
            self.progress.emit(70)

            calculator = CompetitivenessCalculator()
            calculator.load_growth_data(self.data_a, self.data_b, self.data_co)
            calculator.interaction_type = test_results['interaction_type']
            competitiveness = calculator.calculate_competitiveness(
                self.data_a, self.data_b, self.data_co
            )

            self.message.emit("Parameter optimization...")
            self.progress.emit(90)

            optimizer = ParameterOptimizer()
            model_params = optimizer.estimate_parameters_for_optimization(
                self.data_a, self.data_b, self.data_co
            )

            if model_params is None:
                model_params = self.get_default_parameters()

            results = {
                'test_results': test_results,
                'competitiveness': competitiveness,
                'model_params': model_params,
                'interaction_type': test_results['interaction_type'],
                'data': {
                    'solo_A': self.data_a,
                    'solo_B': self.data_b,
                    'co_culture': self.data_co
                }
            }

            self.progress.emit(100)
            self.message.emit("Analysis completed!")

            self.finished.emit(results)

        except Exception as e:
            import traceback
            error_details = traceback.format_exc()
            self.message.emit(f"Analysis error: {str(e)}")
            self.finished.emit({'error': f"{str(e)}\n\nDetailed information:\n{error_details}"})

    def validate_data(self):
        """Return True if all three datasets have non-empty, equal-length 'time' and 'OD'.

        On failure, the reason is stored in ``self.validation_error``.
        """
        self.validation_error = None
        try:
            required_keys = ['time', 'OD']
            for data, name in [(self.data_a, 'Strain A'), (self.data_b, 'Strain B'), (self.data_co, 'Co-culture')]:
                if data is None:
                    raise ValueError(f"{name} data is empty")

                for key in required_keys:
                    if key not in data:
                        raise ValueError(f"{name} data missing '{key}' field")

                if len(data['time']) == 0 or len(data['OD']) == 0:
                    raise ValueError(f"{name} data has empty arrays")

                if len(data['time']) != len(data['OD']):
                    raise ValueError(f"{name} data has mismatched time and OD array lengths")

            return True
        except Exception as e:
            self.validation_error = str(e)
            return False

    def get_default_parameters(self):
        """Return fallback model parameters used when estimation yields None."""
        return {
            'N0_total': 0.1,
            'mu_A_25': 0.3,
            'mu_B_25': 0.4,
            'mu_A_37': 0.015,
            'mu_B_37': 0.6,
            'K_A': 1.0,
            'K_B': 1.2,
            'alpha_BA': 0.5,
            'alpha_AB': 0.5,
            'gamma_A': 0.15,
            'gamma_B': 0.15,
            'Y_cellulose': 0.05,
            'Y_pigment': 0.08,
            'total_time': 24.0
        }


class OptimizationThread(QThread):
    """Background worker that optimises inoculation ratio and temperature-switch time.

    Signals:
        finished(dict): the optimiser's result dict, plus a ``'figures'`` dict
            of matplotlib Figures, or ``{'error': message}`` on failure.
        progress(int): completion percentage.
    """
    finished = pyqtSignal(dict)
    progress = pyqtSignal(int)

    def __init__(self, model_params, interaction_type, weight):
        super().__init__()
        self.model_params = model_params
        self.interaction_type = interaction_type
        self.weight = weight

    def run(self):
        """Run ``ParameterOptimizer.optimize_parameters``, build figures and emit the result."""
        try:
            self.progress.emit(10)

            if self.model_params is None:
                self.finished.emit({'error': 'Model parameters are empty, cannot perform optimization'})
                return

            optimizer = ParameterOptimizer()
            self.progress.emit(50)

            optimal_params = optimizer.optimize_parameters(
                self.model_params, self.interaction_type, self.weight
            )

            if optimal_params is None:
                self.finished.emit({'error': 'Parameter optimization failed'})
                return

            optimization_figures = self.generate_optimization_figures(optimizer,
                                                                      self.model_params,
                                                                      self.interaction_type,
                                                                      self.weight,
                                                                      optimal_params)

            optimal_params['figures'] = optimization_figures

            self.progress.emit(100)

            self.finished.emit(optimal_params)

        except Exception as e:
            import traceback
            error_details = traceback.format_exc()
            self.finished.emit({'error': f"{str(e)}\n\nDetailed information:\n{error_details}"})

    def generate_optimization_figures(self, optimizer, model_params, interaction_type, weight, optimal_params):
        """Return ``{'optimization_results': Figure}``, or ``{}`` if plotting fails."""
        figures = {}

        try:
            fig = self.create_optimization_results_figure(optimizer, model_params,
                                                          interaction_type, weight, optimal_params)
            figures['optimization_results'] = fig

        except Exception as e:
            print(f"Error generating optimization figures: {e}")

        return figures

    @staticmethod
    def _integrate(y, x):
        """Trapezoidal integral of ``y`` over ``x`` on any supported NumPy version.

        ``np.trapezoid`` exists only in NumPy >= 2.0. The installer accepts
        numpy>=1.21.5, where only ``np.trapz`` is available.
        """
        trapezoid = getattr(np, 'trapezoid', None) or np.trapz
        return trapezoid(y, x)

    @staticmethod
    def _objective_curve(optimizer, points, make_args, model_params, interaction_type, weight):
        """Evaluate the (sign-flipped) objective at each point.

        ``make_args`` maps a sweep value to the ``(ratio, switch_time)`` tuple.
        Failed evaluations become NaN so they show as gaps instead of fake zeros.
        """
        values = []
        for point in points:
            try:
                values.append(-optimizer.objective_function(make_args(point), model_params,
                                                            interaction_type, weight))
            except Exception:
                values.append(np.nan)
        return values

    def create_optimization_results_figure(self, optimizer, model_params, interaction_type, weight, optimal_params):
        """Build a 3x2 summary figure.

        The panels are: objective sweeps over ratio and over switch time,
        product yields, the weight split, and predicted biomass dynamics.
        """
        fig = Figure(figsize=(12, 10))
        gs = fig.add_gridspec(3, 2)

        actual_ratio = optimal_params['optimal_ratio']
        actual_switch_time = optimal_params['optimal_switch_time']
        max_objective = optimal_params['max_objective']
        # 24 h matches AnalysisThread.get_default_parameters.
        total_time = model_params.get('total_time', 24.0)

        ax1 = fig.add_subplot(gs[0, 0])
        ratios = np.linspace(0.1, 5.0, 50)
        objectives = self._objective_curve(optimizer, ratios, lambda r: (r, actual_switch_time),
                                           model_params, interaction_type, weight)

        ax1.plot(ratios, objectives, color='#4D6c78', linewidth=2)
        ax1.axvline(actual_ratio, color='#B7686c', linestyle='--',
                    label=f'Optimal: {actual_ratio:.3f}')
        ax1.set_xlabel('Inoculation Ratio (A:B)')
        ax1.set_ylabel('Objective Function Value')
        ax1.set_title('Objective vs Inoculation Ratio\n(Fixed Switch Time)')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        ax2 = fig.add_subplot(gs[0, 1])
        switch_times = np.linspace(2, total_time - 2, 50)
        objectives = self._objective_curve(optimizer, switch_times, lambda s: (actual_ratio, s),
                                           model_params, interaction_type, weight)

        ax2.plot(switch_times, objectives, color='#98AE80', linewidth=2)
        ax2.axvline(actual_switch_time, color='#B7686C', linestyle='--',
                    label=f'Optimal: {actual_switch_time:.1f}h')
        ax2.set_xlabel('Temperature Switch Time (hours)')
        ax2.set_ylabel('Objective Function Value')
        ax2.set_title('Objective vs Switch Time\n(Fixed Ratio)')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        ax3 = fig.add_subplot(gs[1, 0])
        try:
            t_25 = np.linspace(0, actual_switch_time, 100)
            t_37 = np.linspace(actual_switch_time, total_time, 100)

            N_A_25, _ = optimizer.predict_biomass(interaction_type, model_params,
                                                  actual_ratio, t_25, 25)
            _, N_B_37 = optimizer.predict_biomass(interaction_type, model_params,
                                                  actual_ratio, t_37, 37)

            # Cellulose is attributed to strain A's 25°C phase and pigment to
            # strain B's 37°C phase.
            cellulose_yield = model_params['Y_cellulose'] * self._integrate(N_A_25, t_25)
            pigment_yield = model_params['Y_pigment'] * self._integrate(N_B_37, t_37)

            yields = [cellulose_yield, pigment_yield]
        except Exception:
            yields = [0, 0]

        colors = ['#C08081', '#FFE4E1']
        bars = ax3.bar(['Cellulose', 'Pigment'], yields, color=colors)
        ax3.set_ylabel('Yield (g/L)')
        ax3.set_title('Product Yields\n(Optimal Conditions)')

        for bar, yield_val in zip(bars, yields):
            height = bar.get_height()
            ax3.text(bar.get_x() + bar.get_width() / 2., height,
                     f'{yield_val:.4f}', ha='center', va='bottom')

        ax4 = fig.add_subplot(gs[1, 1])
        ax4.pie([weight, 1 - weight], labels=['Cellulose Weight', 'Pigment Weight'],
                colors=colors, autopct='%1.1f%%', startangle=90)
        ax4.set_title('Optimization Weight Distribution')

        ax5 = fig.add_subplot(gs[2, :])
        time_points = np.linspace(0, total_time, 200)

        switch_time = actual_switch_time
        t_25 = time_points[time_points <= switch_time]
        t_37 = time_points[time_points >= switch_time]

        try:
            N_A_25, N_B_25 = optimizer.predict_biomass(interaction_type, model_params,
                                                       actual_ratio, t_25, 25)
            N_A_37, N_B_37 = optimizer.predict_biomass(interaction_type, model_params,
                                                       actual_ratio, t_37, 37)

            ax5.plot(t_25, N_A_25, color='#72A6C5', label='Strain A (25°C)', linewidth=2)
            ax5.plot(t_25, N_B_25, color='#B7686C', label='Strain B (25°C)', linewidth=2)
            ax5.plot(t_37, N_A_37, color='#72A6C5', linestyle='--', label='Strain A (37°C)', linewidth=2)
            ax5.plot(t_37, N_B_37, color='#B7686C', linestyle='--', label='Strain B (37°C)', linewidth=2)
            ax5.axvline(switch_time, color='#2C3E50', linestyle=':',
                        label=f'Temperature switch: {switch_time:.1f}h')

            ax5.set_xlabel('Time (hours)')
            ax5.set_ylabel('Biomass')
            ax5.set_title('Predicted Biomass Dynamics\n(Optimal Conditions)')
            ax5.legend()
            ax5.grid(True, alpha=0.3)

        except Exception as e:
            ax5.text(0.5, 0.5, f'Prediction error: {str(e)}',
                     transform=ax5.transAxes, ha='center', va='center')
            ax5.set_title('Biomass Dynamics Prediction Failed')

        fig.suptitle(
            f'Optimization Target: {"Maximize Cellulose" if weight > 0.5 else "Maximize Pigment" if weight < 0.5 else "Balanced Production"}\n'
            f'Optimal Ratio: {actual_ratio:.3f}, Optimal Switch Time: {actual_switch_time:.1f}h, Max Objective: {max_objective:.4f}',
            fontsize=12, y=0.98)

        fig.tight_layout(rect=[0, 0, 1, 0.96])
        return fig


class MplCanvas(FigureCanvas):
    """Qt widget hosting a blank matplotlib Figure that grows with its layout.

    The figure is exposed as ``self.fig``; no axes are created here.

    Args:
        parent: optional Qt parent widget.
        width, height: initial figure size in inches.
        dpi: figure resolution in dots per inch.
    """

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        figure = Figure(figsize=(width, height), dpi=dpi)
        self.fig = figure
        super().__init__(figure)
        self.setParent(parent)

        # Let the canvas take all the space the surrounding layout offers.
        grow = QSizePolicy.Expanding
        self.setSizePolicy(grow, grow)
        self.updateGeometry()


class GrowthCurveCanvas(FigureCanvas):
    """Matplotlib canvas plotting OD-versus-time growth curves.

    Shows up to three curves: Strain A alone, Strain B alone and the co-culture.
    The single axes object is exposed as ``self.axes``.
    """

    def __init__(self, parent=None, width=8, height=6, dpi=100):
        self.fig = Figure(figsize=(width, height), dpi=dpi)
        super().__init__(self.fig)
        self.setParent(parent)

        self.axes = self.fig.add_subplot(111)
        self._decorate_axes('Growth Curve Comparison')

    def _decorate_axes(self, title):
        """Apply the shared grid, axis labels and the given title to ``self.axes``."""
        self.axes.grid(True, alpha=0.3)
        self.axes.set_xlabel('Time (hours)')
        self.axes.set_ylabel('OD value')
        self.axes.set_title(title)

    def plot_curves(self, data_a, data_b, data_co):
        """Replace the current plot with the given growth curves and redraw.

        Args:
            data_a, data_b, data_co: mappings with 'time' and 'OD' sequences, or
                None. Entries that are falsy or lack either key are skipped.

        Plotting errors are printed rather than raised, so a bad data set does
        not propagate out of the GUI callback that triggered the redraw.
        """
        self.axes.clear()

        curves = (
            (data_a, 'b-', 'Strain A was cultured alone'),
            (data_b, 'r-', 'Strain B was cultured alone'),
            (data_co, 'g-', 'Co-cultivation'),
        )

        try:
            for data, fmt, label in curves:
                if data and 'time' in data and 'OD' in data:
                    self.axes.plot(data['time'], data['OD'], fmt, label=label, linewidth=2)

            # legend() warns ("No artists with labels found") and draws an empty
            # box when nothing was plotted, so only add it when curves exist.
            handles, _ = self.axes.get_legend_handles_labels()
            if handles:
                self.axes.legend()
            self._decorate_axes('Comparison of growth curves')

            self.draw()
        except Exception as e:
            print(f"Error plotting growth curves: {e}")


class MainWindow(QMainWindow):
    """Main window of the co-culture analysis GUI.

    The window has three tabs:

    * "Data Upload": load the Strain A, Strain B and co-culture CSV files and
      start the analysis.
    * "Analysis Results": interaction type, competitiveness, statistical tests,
      growth curves and model parameters.
    * "Parameter Optimization": run the optimizer and show its figure and a
      text report.

    Long-running work is delegated to ``AnalysisThread`` and
    ``OptimizationThread``. Their ``finished`` signals deliver result mappings
    to ``on_analysis_finished`` and ``on_optimization_finished``.
    """

    # Human-readable descriptions of known model parameters for the params table.
    PARAM_DESCRIPTIONS = {
        'N0_total': 'Total initial biomass',
        'mu_A_25': 'Growth rate of Strain A at 25°C',
        'mu_B_25': 'Growth rate of Strain B at 25°C',
        'mu_A_37': 'Growth rate of Strain A at 37°C',
        'mu_B_37': 'Growth rate of Strain B at 37°C',
        'K_A': 'Maximum biomass of Strain A',
        'K_B': 'Maximum biomass of Strain B',
        'alpha_BA': 'Competition coefficient of Strain B on Strain A',
        'alpha_AB': 'Competition coefficient of Strain A on Strain B',
        'gamma_A': 'Mutualism coefficient of Strain A',
        'gamma_B': 'Mutualism coefficient of Strain B',
        'Y_cellulose': 'Cellulose yield coefficient',
        'Y_pigment': 'Pigment yield coefficient',
        'total_time': 'Total cultivation time'
    }

    # Objective weight passed to OptimizationThread for each combo-box target.
    # Any other text falls back to the balanced weight 0.5.
    OPTIMIZATION_WEIGHTS = {
        "Maximize Cellulose": 0.9,
        "Maximize Pigment": 0.1,
        "Balanced Production": 0.5,
    }

    def __init__(self):
        super().__init__()
        # Uploaded curves: dicts with 'time' and 'OD' arrays, or None until loaded.
        self.data_a = None
        self.data_b = None
        self.data_co = None
        # Result mapping from the last successful AnalysisThread run.
        self.analysis_results = None
        # Figures from the last OptimizationThread run, keyed by name.
        self.optimization_figures = {}
        # Optimization target chosen when the current/last optimization started.
        self.optimization_target = None
        # Tab widget; created in init_ui.
        self.tabs = None

        self.init_ui()

    def init_ui(self):
        """Build the window chrome, the global stylesheet, the title and the three tabs."""
        self.setWindowTitle('Microbial Co-culture Interaction Analysis System')
        self.setGeometry(100, 100, 1400, 900)

        self.setStyleSheet("""
            QMainWindow {
                background-color: #f5f7fa;
            }
            QGroupBox {
                font-weight: bold;
                border: 2px solid #cccccc;
                border-radius: 5px;
                margin-top: 1ex;
                padding-top: 10px;
            }
            QGroupBox::title {
                subcontrol-origin: margin;
                left: 10px;
                padding: 0 5px 0 5px;
            }
            QPushButton {
                background-color: #3498db;
                color: white;
                border: none;
                padding: 8px 16px;
                border-radius: 4px;
                font-weight: bold;
            }
            QPushButton:hover {
                background-color: #2980b9;
            }
            QPushButton:pressed {
                background-color: #21618c;
            }
            QPushButton:disabled {
                background-color: #bdc3c7;
                color: #7f8c8d;
            }
            QTextEdit {
                border: 1px solid #cccccc;
                border-radius: 4px;
                padding: 5px;
            }
            QProgressBar {
                border: 1px solid #cccccc;
                border-radius: 4px;
                text-align: center;
            }
            QProgressBar::chunk {
                background-color: #2ecc71;
                width: 20px;
            }
        """)

        central_widget = QWidget()
        self.setCentralWidget(central_widget)

        main_layout = QVBoxLayout(central_widget)

        title_label = QLabel('Microbial Co-culture Interaction Analysis System')
        title_label.setAlignment(Qt.AlignCenter)
        title_font = QFont()
        title_font.setPointSize(18)
        title_font.setBold(True)
        title_label.setFont(title_font)
        title_label.setStyleSheet("color: #2c3e50; margin: 15px;")
        main_layout.addWidget(title_label)

        # Kept on self so handlers can switch tabs without searching the widget tree.
        self.tabs = QTabWidget()
        main_layout.addWidget(self.tabs)

        data_tab = QWidget()
        self.tabs.addTab(data_tab, "Data Upload")
        self.setup_data_tab(data_tab)

        results_tab = QWidget()
        self.tabs.addTab(results_tab, "Analysis Results")
        self.setup_results_tab(results_tab)

        optimization_tab = QWidget()
        self.tabs.addTab(optimization_tab, "Parameter Optimization")
        self.setup_optimization_tab(optimization_tab)

    def _add_upload_group(self, parent_layout, title, strain_type):
        """Add a titled group with a 'Select CSV File' button and a file-name label.

        Args:
            parent_layout: layout the group is appended to.
            title: group-box title.
            strain_type: key passed to ``upload_file`` when the button is clicked.

        Returns:
            (button, label) so the caller can keep references to both widgets.
        """
        group = QGroupBox(title)
        group_layout = QVBoxLayout(group)
        button = QPushButton("Select CSV File")
        button.clicked.connect(lambda: self.upload_file(strain_type))
        group_layout.addWidget(button)
        label = QLabel("No file selected")
        group_layout.addWidget(label)
        parent_layout.addWidget(group)
        return button, label

    def setup_data_tab(self, parent):
        """Populate the "Data Upload" tab.

        The tab holds the upload groups, the preview, the Analyze button, the
        progress bar and the status label.
        """
        layout = QVBoxLayout(parent)

        upload_group = QGroupBox("Data File Upload")
        upload_layout = QHBoxLayout(upload_group)

        self.a_upload_btn, self.a_file_label = self._add_upload_group(
            upload_layout, "Strain A Solo Culture", 'a')
        self.b_upload_btn, self.b_file_label = self._add_upload_group(
            upload_layout, "Strain B Solo Culture", 'b')
        self.co_upload_btn, self.co_file_label = self._add_upload_group(
            upload_layout, "Co-culture", 'co')

        layout.addWidget(upload_group)

        preview_group = QGroupBox("Data Preview")
        preview_layout = QHBoxLayout(preview_group)

        self.preview_text = QTextEdit()
        self.preview_text.setMaximumHeight(200)
        self.preview_text.setPlaceholderText("Data preview will be shown after selecting files...")
        preview_layout.addWidget(self.preview_text)

        layout.addWidget(preview_group)

        self.analyze_btn = QPushButton("Start Analysis")
        self.analyze_btn.clicked.connect(self.analyze_data)
        self.analyze_btn.setEnabled(False)
        layout.addWidget(self.analyze_btn)

        self.progress_bar = QProgressBar()
        self.progress_bar.setVisible(False)
        layout.addWidget(self.progress_bar)

        self.status_label = QLabel("Ready")
        layout.addWidget(self.status_label)

    def setup_results_tab(self, parent):
        """Populate the "Analysis Results" tab.

        The top pane holds the summary boxes; the bottom pane holds the
        growth-curve chart and the parameter table.
        """
        layout = QVBoxLayout(parent)

        splitter = QSplitter(Qt.Vertical)
        layout.addWidget(splitter)

        top_frame = QFrame()
        top_layout = QHBoxLayout(top_frame)

        interaction_group = QGroupBox("Interaction Type")
        interaction_layout = QVBoxLayout(interaction_group)
        self.interaction_label = QLabel("Not analyzed")
        self.interaction_label.setAlignment(Qt.AlignCenter)
        interaction_font = QFont()
        interaction_font.setPointSize(16)
        interaction_font.setBold(True)
        self.interaction_label.setFont(interaction_font)
        self.interaction_label.setStyleSheet("padding: 20px;")
        interaction_layout.addWidget(self.interaction_label)
        top_layout.addWidget(interaction_group)

        competitiveness_group = QGroupBox("Competitiveness Analysis")
        competitiveness_layout = QVBoxLayout(competitiveness_group)
        self.competitiveness_label = QLabel("Not analyzed")
        self.competitiveness_label.setAlignment(Qt.AlignCenter)
        competitiveness_layout.addWidget(self.competitiveness_label)
        top_layout.addWidget(competitiveness_group)

        stats_group = QGroupBox("Statistical Test Results")
        stats_layout = QFormLayout(stats_group)
        self.k_observed_label = QLabel("--")
        self.k_predicted_label = QLabel("--")
        self.p_value_k_label = QLabel("--")
        self.p_value_auc_label = QLabel("--")

        stats_layout.addRow("Observed K:", self.k_observed_label)
        stats_layout.addRow("Predicted K:", self.k_predicted_label)
        stats_layout.addRow("K test p-value:", self.p_value_k_label)
        stats_layout.addRow("AUC test p-value:", self.p_value_auc_label)
        top_layout.addWidget(stats_group)

        splitter.addWidget(top_frame)

        bottom_frame = QFrame()
        bottom_layout = QVBoxLayout(bottom_frame)

        chart_group = QGroupBox("Growth Curves")
        chart_layout = QVBoxLayout(chart_group)

        chart_buttons_layout = QHBoxLayout()
        self.export_growth_curve_btn = QPushButton("Export Growth Curve as SVG")
        self.export_growth_curve_btn.clicked.connect(self.export_growth_curve_svg)
        self.export_growth_curve_btn.setEnabled(False)
        self.export_growth_curve_btn.setMinimumHeight(35)
        chart_buttons_layout.addWidget(self.export_growth_curve_btn)
        chart_buttons_layout.addStretch()
        chart_layout.addLayout(chart_buttons_layout)

        self.chart_canvas = GrowthCurveCanvas(self, width=10, height=6, dpi=100)
        chart_layout.addWidget(self.chart_canvas)
        bottom_layout.addWidget(chart_group)

        params_group = QGroupBox("Detailed Parameters")
        params_layout = QVBoxLayout(params_group)
        self.params_table = QTableWidget()
        self.params_table.setColumnCount(3)
        self.params_table.setHorizontalHeaderLabels(["Parameter", "Value", "Description"])
        self.params_table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
        params_layout.addWidget(self.params_table)
        bottom_layout.addWidget(params_group)

        splitter.addWidget(bottom_frame)

        splitter.setSizes([200, 600])

    def setup_optimization_tab(self, parent):
        """Populate the "Parameter Optimization" tab.

        The tab holds the target selector, the chart canvas and the text report.
        """
        layout = QVBoxLayout(parent)

        settings_group = QGroupBox("Optimization Settings")
        settings_layout = QFormLayout(settings_group)

        self.weight_combo = QComboBox()
        self.weight_combo.addItems(["Maximize Cellulose", "Maximize Pigment", "Balanced Production"])
        settings_layout.addRow("Optimization Target:", self.weight_combo)

        self.optimize_btn = QPushButton("Start Optimization")
        self.optimize_btn.clicked.connect(self.run_optimization)
        self.optimize_btn.setEnabled(False)
        settings_layout.addRow(self.optimize_btn)

        layout.addWidget(settings_group)

        # Distinct title: the text-report group below is "Optimization Results".
        charts_group = QGroupBox("Optimization Charts")
        charts_layout = QVBoxLayout(charts_group)

        chart_buttons_layout = QHBoxLayout()
        self.export_optimization_btn = QPushButton("Export Optimization Chart as SVG")
        self.export_optimization_btn.clicked.connect(self.export_optimization_chart_svg)
        self.export_optimization_btn.setEnabled(False)
        self.export_optimization_btn.setMinimumHeight(35)
        chart_buttons_layout.addWidget(self.export_optimization_btn)
        chart_buttons_layout.addStretch()
        charts_layout.addLayout(chart_buttons_layout)

        chart_frame = QFrame()
        chart_frame_layout = QVBoxLayout(chart_frame)
        self.optimization_canvas = MplCanvas(self, width=12, height=10, dpi=100)
        chart_frame_layout.addWidget(self.optimization_canvas)

        charts_layout.addWidget(chart_frame)

        layout.addWidget(charts_group)

        results_group = QGroupBox("Optimization Results")
        results_layout = QVBoxLayout(results_group)

        self.optimization_results = QTextEdit()
        self.optimization_results.setPlaceholderText("Optimization results will be displayed here...")
        self.optimization_results.setMinimumHeight(200)
        results_layout.addWidget(self.optimization_results)

        layout.addWidget(results_group)

        layout.setStretchFactor(charts_group, 2)
        layout.setStretchFactor(results_group, 1)

    def upload_file(self, strain_type):
        """Ask for a CSV file and load it as the data set for ``strain_type``.

        Args:
            strain_type: 'a' (Strain A solo), 'b' (Strain B solo) or 'co'
                (co-culture).

        The CSV must contain numeric 'time' and 'OD' columns and at least one
        row. Rejected files are reported in a message box. The previously
        loaded data for that slot stays untouched.
        """
        file_path, _ = QFileDialog.getOpenFileName(
            self, f'Select Strain {strain_type.upper()} Data File',
            '', 'CSV Files (*.csv)'
        )
        if not file_path:
            return

        try:
            df = pd.read_csv(file_path)

            if 'time' not in df.columns or 'OD' not in df.columns:
                QMessageBox.warning(self, "File Format Error",
                                    "CSV file must contain 'time' and 'OD' columns")
                return
            # Validate before storing anything. Otherwise an empty or non-numeric
            # file would replace the current data and then fail later, for
            # example while building the preview.
            if df.empty:
                QMessageBox.warning(self, "File Format Error",
                                    "CSV file contains no data rows")
                return
            if not all(pd.api.types.is_numeric_dtype(df[col]) for col in ('time', 'OD')):
                QMessageBox.warning(self, "File Format Error",
                                    "'time' and 'OD' columns must contain numeric values")
                return

            data = {
                'time': df['time'].values,
                'OD': df['OD'].values
            }
            file_name = os.path.basename(file_path)

            if strain_type == 'a':
                self.data_a = data
                self.a_file_label.setText(file_name)
            elif strain_type == 'b':
                self.data_b = data
                self.b_file_label.setText(file_name)
            elif strain_type == 'co':
                self.data_co = data
                self.co_file_label.setText(file_name)

            self.update_preview()
            self.check_analysis_ready()

        except Exception as e:
            QMessageBox.critical(self, "File Read Error", f"Cannot read file: {str(e)}")

    def update_preview(self):
        """Show point count, time range and OD range for every loaded data set."""
        # (title, data, extra trailing text). The co-culture block has no blank
        # line after it.
        sections = (
            ("Strain A Data", self.data_a, "\n"),
            ("Strain B Data", self.data_b, "\n"),
            ("Co-culture Data", self.data_co, ""),
        )

        parts = []
        for title, data, trailer in sections:
            if data is None:
                continue
            times, od = data['time'], data['OD']
            parts.append(
                f"{title}:\n"
                f"Time points: {len(times)}\n"
                f"Time range: {times[0]:.2f} - {times[-1]:.2f}\n"
                f"OD range: {min(od):.4f} - {max(od):.4f}\n"
                f"{trailer}"
            )

        self.preview_text.setText("".join(parts))

    def _is_thread_running(self, attr_name):
        """Return True if the worker stored in ``self.<attr_name>`` is still running.

        Assumes a QThread-like worker exposing ``isRunning()``. A missing
        attribute or a worker without ``isRunning()`` counts as not running.
        """
        thread = getattr(self, attr_name, None)
        is_running = getattr(thread, 'isRunning', None)
        return bool(is_running and is_running())

    def check_analysis_ready(self):
        """Enable the Analyze and Optimize buttons once all three data sets are loaded.

        A button stays disabled while its worker thread is still running.
        Otherwise a second click would replace, and let Python destroy, a
        running QThread.
        """
        ready = all(d is not None for d in (self.data_a, self.data_b, self.data_co))
        self.analyze_btn.setEnabled(ready and not self._is_thread_running('analysis_thread'))
        self.optimize_btn.setEnabled(ready and not self._is_thread_running('optimization_thread'))

    def analyze_data(self):
        """Start an AnalysisThread on the three loaded data sets.

        Progress, messages and the final result are wired to the progress bar,
        the status label and ``on_analysis_finished``.
        """
        if any(d is None for d in (self.data_a, self.data_b, self.data_co)):
            QMessageBox.warning(self, "Incomplete Data", "Please upload all data files first")
            return
        if self._is_thread_running('analysis_thread'):
            # An analysis is already in progress; never drop a running thread.
            return

        self.analyze_btn.setEnabled(False)

        self.progress_bar.setVisible(True)
        self.progress_bar.setValue(0)

        self.analysis_thread = AnalysisThread(self.data_a, self.data_b, self.data_co)
        self.analysis_thread.finished.connect(self.on_analysis_finished)
        self.analysis_thread.progress.connect(self.progress_bar.setValue)
        self.analysis_thread.message.connect(self.status_label.setText)
        self.analysis_thread.start()

    def on_analysis_finished(self, results):
        """Handle the analysis result.

        An 'error' entry is shown in a message box. Otherwise the result is
        stored, displayed, and the "Analysis Results" tab is opened.
        """
        self.analyze_btn.setEnabled(True)
        self.progress_bar.setVisible(False)

        if 'error' in results:
            QMessageBox.critical(self, "Analysis Error", results['error'])
            return

        self.analysis_results = results
        self.update_results_display()
        self.tabs.setCurrentIndex(1)

    def _interaction_type(self):
        """Return the interaction type of the current analysis.

        Prefers a top-level 'interaction_type' entry and falls back to the one
        stored under 'test_results', which the results tab reads.
        """
        results = self.analysis_results
        if 'interaction_type' in results:
            return results['interaction_type']
        return results['test_results']['interaction_type']

    def update_results_display(self):
        """Fill the results tab from ``self.analysis_results``.

        Covers interaction type, competitiveness, test statistics, curves and
        parameters.
        """
        if not self.analysis_results:
            return

        test_results = self.analysis_results['test_results']
        competitiveness = self.analysis_results['competitiveness']
        model_params = self.analysis_results['model_params']

        interaction_type = test_results['interaction_type']
        self.interaction_label.setText(interaction_type)

        # Red for competition, green for mutualism, blue for anything else.
        color = {"Competition": "#e74c3c", "Mutualism": "#2ecc71"}.get(interaction_type, "#3498db")
        self.interaction_label.setStyleSheet(
            f"background-color: {color}; color: white; padding: 20px; border-radius: 5px;")

        comp_text = f"Competitiveness Index: {competitiveness:.4f}\n\n"
        if interaction_type == "Competition":
            if competitiveness > 1:
                comp_text += "Strain B is more competitive than Strain A"
            elif competitiveness < 1:
                comp_text += "Strain A is more competitive than Strain B"
            else:
                comp_text += "Strain A and Strain B are equally competitive"
        elif interaction_type == "Neutral":
            if competitiveness > 1:
                comp_text += "Strain A grows faster than Strain B"
            else:
                comp_text += "Strain B grows faster than Strain A"
        else:
            if competitiveness > 1:
                comp_text += "Co-culture shows synergistic effect"
            else:
                comp_text += "Co-culture does not show synergistic effect"

        self.competitiveness_label.setText(comp_text)

        self.k_observed_label.setText(f"{test_results['K_observed']:.4f}")
        self.k_predicted_label.setText(f"{test_results['K_predicted']:.4f}")
        self.p_value_k_label.setText(f"{test_results['p_value_K']:.4f}")
        self.p_value_auc_label.setText(f"{test_results['p_value_AUC']:.4f}")

        self.chart_canvas.plot_curves(
            self.data_a, self.data_b, self.data_co
        )

        self.export_growth_curve_btn.setEnabled(True)

        self.update_params_table(model_params)

    @staticmethod
    def _format_param_value(value):
        """Format a parameter value with 4 decimals, or with ``str()`` if it is not numeric."""
        try:
            return f"{value:.4f}"
        except (TypeError, ValueError):
            return str(value)

    def update_params_table(self, params):
        """Show one (name, value, description) row per entry of ``params``.

        Does nothing when ``params`` is None.
        """
        if params is None:
            return

        self.params_table.setRowCount(len(params))

        for row, (key, value) in enumerate(params.items()):
            self.params_table.setItem(row, 0, QTableWidgetItem(str(key)))
            self.params_table.setItem(row, 1, QTableWidgetItem(self._format_param_value(value)))
            description = self.PARAM_DESCRIPTIONS.get(key, "Model parameter")
            self.params_table.setItem(row, 2, QTableWidgetItem(description))

    def run_optimization(self):
        """Start an OptimizationThread for the selected target.

        Requires a completed analysis. The chosen target is remembered so the
        report matches this run even if the combo box changes meanwhile.
        """
        if not self.analysis_results:
            QMessageBox.warning(self, "No Analysis Results", "Please complete data analysis first")
            return
        if self._is_thread_running('optimization_thread'):
            # An optimization is already in progress; never drop a running thread.
            return

        target = self.weight_combo.currentText()
        weight = self.OPTIMIZATION_WEIGHTS.get(target, 0.5)
        self.optimization_target = target

        self.optimize_btn.setEnabled(False)

        self.optimization_thread = OptimizationThread(
            self.analysis_results['model_params'],
            self._interaction_type(),
            weight
        )
        self.optimization_thread.finished.connect(self.on_optimization_finished)
        self.optimization_results.setText("Optimizing parameters, please wait...")
        self.optimization_thread.start()

    def on_optimization_finished(self, results):
        """Show the optimization report and chart, or the error the thread reported."""
        self.optimize_btn.setEnabled(True)

        if 'error' in results:
            self.optimization_results.setText(f"Optimization error: {results['error']}")
            return

        self.optimization_figures = results.get('figures', {})
        self.update_optimization_results_text(results)
        self.update_optimization_charts()

        self.export_optimization_btn.setEnabled(True)

    def update_optimization_results_text(self, results):
        """Render the optimization report: optimum, recommendations and predicted yields.

        Args:
            results: mapping from OptimizationThread with 'optimal_ratio',
                'optimal_switch_time' and 'max_objective'.
        """
        ratio = results['optimal_ratio']
        switch_time = results['optimal_switch_time']
        max_objective = results['max_objective']
        model_params = self.analysis_results['model_params']

        # Display A:B with the smaller side normalised to 1.
        if ratio < 1:
            ratio_display = f"1:{1 / ratio:.2f}" if ratio > 0 else "N/A"
        else:
            ratio_display = f"{ratio:.2f}:1"

        parts = [
            "=== Optimization Results ===\n\n",
            f"Optimal inoculation ratio (A:B): {ratio_display}\n",
            f"Optimal temperature switch time: {switch_time:.1f} hours\n",
            f"Maximum objective function value: {max_objective:.4f}\n\n",
            "=== Production Recommendations ===\n\n",
            f"1. Set initial inoculation ratio as Strain A:Strain B = {ratio_display}\n",
            f"2. Switch temperature from 25°C to 37°C after {switch_time:.1f} hours\n",
            f"3. Recommended total cultivation time: {model_params['total_time']:.1f} hours\n",
        ]

        # Use the target this run was started with, not the combo's current
        # text, which may have changed while the optimization was running.
        target = self.optimization_target or self.weight_combo.currentText()
        if target == "Maximize Cellulose":
            parts.append("4. This setting prioritizes cellulose production, suitable for cellulose production scenarios\n")
        elif target == "Maximize Pigment":
            parts.append("4. This setting prioritizes pigment production, suitable for pigment production scenarios\n")
        else:
            parts.append("4. This setting balances cellulose and pigment production, suitable for comprehensive production scenarios\n")

        try:
            interaction_type = self._interaction_type()
            t_25 = np.linspace(0, switch_time, 100)
            t_37 = np.linspace(switch_time, model_params['total_time'], 100)

            optimizer = ParameterOptimizer()
            # NOTE(review): both phases are simulated from the same initial ratio
            # rather than continuing from the 25°C end state. Confirm this is the
            # intended model for predict_biomass.
            N_A_25, _ = optimizer.predict_biomass(interaction_type, model_params,
                                                  ratio, t_25, 25)
            _, N_B_37 = optimizer.predict_biomass(interaction_type, model_params,
                                                  ratio, t_37, 37)

            # np.trapezoid only exists in NumPy >= 2.0; np.trapz is the older name.
            integrate = getattr(np, 'trapezoid', None) or np.trapz
            # Cellulose comes from A's 25°C biomass, pigment from B's 37°C biomass.
            cellulose_yield = model_params['Y_cellulose'] * integrate(N_A_25, t_25)
            pigment_yield = model_params['Y_pigment'] * integrate(N_B_37, t_37)

            parts.extend([
                "\n=== Yield Prediction ===\n",
                f"Predicted cellulose yield: {cellulose_yield:.4f} g/L\n",
                f"Predicted pigment yield: {pigment_yield:.4f} g/L\n",
                f"Total value indicator: {max_objective:.4f}\n",
            ])

        except Exception as e:
            parts.append(f"\nYield prediction calculation failed: {str(e)}\n")

        self.optimization_results.setText("".join(parts))

    def update_optimization_charts(self):
        """Display the 'optimization_results' figure on the optimization canvas."""
        fig = self.optimization_figures.get('optimization_results')
        if fig is None:
            return

        canvas = self.optimization_canvas
        # Assigning canvas.figure alone leaves fig.canvas pointing at the canvas
        # the figure was created with, so attach it properly.
        fig.set_canvas(canvas)
        canvas.figure = fig
        canvas.fig = fig

        # Fit the figure to the widget. Otherwise it keeps its creation size and
        # is clipped until the next resize event.
        width_px, height_px = canvas.width(), canvas.height()
        if width_px > 0 and height_px > 0:
            fig.set_size_inches(width_px / fig.dpi, height_px / fig.dpi, forward=False)

        canvas.draw()

    def _ask_svg_path(self, caption):
        """Ask for a save path with a ``.svg`` extension.

        Returns None if the dialog is cancelled.
        """
        file_path, _ = QFileDialog.getSaveFileName(
            self, caption,
            "", "SVG Files (*.svg)"
        )
        if not file_path:
            return None
        # Case-insensitive, so "plot.SVG" does not become "plot.SVG.svg".
        if not file_path.lower().endswith('.svg'):
            file_path += '.svg'
        return file_path

    def export_growth_curve_svg(self):
        """Save the growth-curve chart as an SVG file chosen by the user."""
        # A Figure object is always truthy; check whether anything was plotted.
        if not self.chart_canvas.axes.has_data():
            QMessageBox.warning(self, "Export Error", "No growth curve data to export")
            return

        file_path = self._ask_svg_path("Save Growth Curve as SVG")
        if not file_path:
            return

        try:
            self.chart_canvas.figure.savefig(file_path, format='svg', dpi=300, bbox_inches='tight')
            QMessageBox.information(self, "Export Successful",
                                    f"Growth curve successfully exported to:\n{file_path}")
        except Exception as e:
            QMessageBox.critical(self, "Export Error", f"Failed to export growth curve: {str(e)}")

    def export_optimization_chart_svg(self):
        """Save the optimization figure as an SVG file chosen by the user."""
        if not self.optimization_figures or 'optimization_results' not in self.optimization_figures:
            QMessageBox.warning(self, "Export Error", "No optimization chart data to export")
            return

        file_path = self._ask_svg_path("Save Optimization Chart as SVG")
        if not file_path:
            return

        try:
            self.optimization_figures['optimization_results'].savefig(
                file_path, format='svg', dpi=300, bbox_inches='tight'
            )
            QMessageBox.information(self, "Export Successful",
                                    f"Optimization chart successfully exported to:\n{file_path}")
        except Exception as e:
            QMessageBox.critical(self, "Export Error", f"Failed to export optimization chart: {str(e)}")


def main():
    """Create the Qt application, show the main window and exit with the event loop's status."""
    application = QApplication(sys.argv)
    application.setApplicationName("Microbial Co-culture Analysis System")
    application.setApplicationVersion("1.0")

    main_window = MainWindow()
    main_window.show()

    exit_code = application.exec_()
    sys.exit(exit_code)


if __name__ == '__main__':
    main()