In [1]:
# Import necessary libraries
from IPython.display import HTML, Image
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import os
import seaborn as sns
from IPython.display import display, HTML
def hide_code_toggle():
    """
    Inject a floating button into the notebook page that shows/hides code blocks.

    Displays an HTML snippet (CSS + JavaScript) through IPython's ``display``.
    The script wraps every ``<pre>`` element in a ``.code-container`` div and
    slides those containers in and out when the button is clicked. Nothing is
    returned.
    """
    display(HTML('''
<style>
.code-toggle {
    position: fixed;
    top: 20px;
    left: 20px;
    z-index: 1000;
    padding: 10px 17px;
    background: #1570FF; /* Bright blue color */
    color: white;
    border: none;
    border-radius: 6px;
    cursor: pointer;
    font-size: 14px;
    font-weight: bold;
    box-shadow: 0 2px 6px rgba(0,0,0,0.2);
    transition: all 0.3s ease;
}
.code-toggle:hover {
    background: #1976D2;
    box-shadow: 0 4px 8px rgba(0,0,0,0.3);
    transform: translateY(-1px);
}
.code-toggle:active {
    transform: translateY(1px);
}
.code-container {
    margin: 10px 0;
    transition: all 0.3s ease;
}
</style>
<script>
(function () {
    function initToggle($) {
        $(function () {
            // Re-running the cell must not add a second button or re-wrap the blocks.
            if ($('.code-toggle').length) {
                return;
            }
            // Arrows are written as HTML entities so the label survives any
            // source-file / page encoding (the literal characters got mangled).
            // Code starts visible, so the first action offered is "Hide".
            var toggleButton = $('<button class="code-toggle">&#9660; Hide Code</button>');
            $('body').prepend(toggleButton);
            // Wrap all code blocks in containers
            $('pre').each(function () {
                $(this).wrap('<div class="code-container"></div>');
            });
            var codeVisible = true;
            toggleButton.click(function () {
                codeVisible = !codeVisible;
                $('.code-container').slideToggle(300);
                $(this).html(codeVisible ? '&#9660; Hide Code' : '&#9650; Show Code');
                $(this).toggleClass('code-hidden', !codeVisible);
            });
        });
    }
    // Reuse the page's jQuery when it already has one; injecting a second copy
    // would replace the global $ that other scripts on the page may rely on.
    if (window.jQuery) {
        initToggle(window.jQuery);
    } else {
        var script = document.createElement('script');
        script.src = 'https://code.jquery.com/jquery-3.6.0.min.js';
        script.onload = function () { initToggle(window.jQuery); };
        document.head.appendChild(script);
    }
})();
</script>
'''))
# Inject the code show/hide button into the page as soon as this cell runs.
hide_code_toggle()

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Input workbook and destination of the rendered GIF (machine-specific paths).
DATA_FILE = r"C:/Users/11/Desktop/Python/Python Project/Trade_China_2000-2023/trade_data_updated.xlsx"
OUTPUT_SURPLUS_GIF = r"C:\Users\11\PycharmProjects\PythonProject2\top10_surplus_animation_smooth.gif"

# Inclusive range of year columns looked up in the workbook.
YEAR_START, YEAR_END = 2000, 2024

# Animation timing: frame rate of the saved GIF, and the number of
# interpolated frames used to move from one year to the next.
FPS = 7
TRANSITION_STEPS = 7

# NOTE(review): not referenced by the visible code; originally described as
# "avoids zero values in visualization".
MIN_VALUE = 1

# Light grey face colour for the plot area.
BACKGROUND_COLOR = "#f4f4f4"

# Helper functions
def clean_country_names(df):
    """
    Normalise the 'Country' column of *df* in place and return *df*.

    Every value is converted to ``str``, stripped of surrounding whitespace,
    and internal whitespace runs are collapsed to a single space.

    Raises:
        KeyError: if *df* has no 'Country' column.
    """
    if 'Country' not in df.columns:
        raise KeyError("Missing 'Country' column in the dataset.")
    names = df['Country'].astype(str).str.strip()
    df['Country'] = names.str.replace(r'\s+', ' ', regex=True)
    return df
def load_and_prepare_data():
    """
    Load the trade workbook and compute per-country, per-year trade surpluses.

    Reads DATA_FILE, normalises country names, reshapes the YEAR_START..YEAR_END
    year columns into long format, sums export and import volumes per
    (Country, Year) and computes ``Surplus = exports - imports``. Only the 10
    countries with the highest *mean* surplus across the loaded years are kept.

    Returns:
        tuple: ``(trade_data, top_countries)``.
        ``trade_data`` is a DataFrame with columns Country, Year,
        TradeVolume_export, TradeVolume_import and Surplus, restricted to
        the top countries.
        ``top_countries`` is the list of those names, highest mean surplus
        first.
        On any error the message is printed and ``(None, None)`` is returned.
    """
    print("Loading and preparing data...")
    try:
        df = pd.read_excel(DATA_FILE)
        df = clean_country_names(df)
        # The trade-type label may live in an unnamed first column of the sheet.
        if 'TradeType' not in df.columns:
            print("Warning: 'TradeType' column not found, checking 'Unnamed: 0'.")
            if 'Unnamed: 0' in df.columns:
                df.rename(columns={'Unnamed: 0': 'TradeType'}, inplace=True)
            else:
                df['TradeType'] = 'Export'  # Default to Export if no 'TradeType' column exists.
        # Tolerate stray whitespace so e.g. 'Export ' still matches 'Export'.
        df['TradeType'] = df['TradeType'].astype(str).str.strip()

        # Year headers may have been read as integers or as strings.
        year_columns = [year for year in range(YEAR_START, YEAR_END + 1) if year in df.columns]
        if not year_columns:
            year_columns = [str(year) for year in range(YEAR_START, YEAR_END + 1) if str(year) in df.columns]
        if not year_columns:
            raise KeyError("Year columns not found in the dataset.")

        # 'TradeType' is guaranteed to exist at this point (found, renamed or defaulted).
        df = df.melt(id_vars=['TradeType', 'Country'], value_vars=year_columns,
                     var_name='Year', value_name='TradeVolume')
        df['Year'] = df['Year'].astype(int)
        df['TradeVolume'] = pd.to_numeric(df['TradeVolume'], errors='coerce').fillna(0)

        # Aggregate export and import volumes by country and year.
        export_data = (df[df['TradeType'] == 'Export']
                       .groupby(['Country', 'Year'], as_index=False)['TradeVolume'].sum())
        import_data = (df[df['TradeType'] == 'Import']
                       .groupby(['Country', 'Year'], as_index=False)['TradeVolume'].sum())

        # Outer join: a (Country, Year) with exports but no import rows (or the
        # reverse) still has a surplus, so the missing side counts as 0 instead
        # of the row being dropped as an inner join would do.
        trade_data = pd.merge(export_data, import_data, on=['Country', 'Year'],
                              how='outer', suffixes=('_export', '_import'))
        volume_columns = ['TradeVolume_export', 'TradeVolume_import']
        trade_data[volume_columns] = trade_data[volume_columns].fillna(0)
        trade_data['Surplus'] = trade_data['TradeVolume_export'] - trade_data['TradeVolume_import']

        # Rank countries by their average surplus over the loaded years.
        top_countries = (trade_data.groupby('Country', as_index=False)['Surplus'].mean()
                         .sort_values(by='Surplus', ascending=False)
                         .head(10)['Country'].tolist())
        print("Top 10 Countries by Trade Surplus:", top_countries)

        trade_data = trade_data[trade_data['Country'].isin(top_countries)]
        print("Data successfully processed\n")
        return trade_data, top_countries
    except Exception as e:
        # Best-effort loader: the caller checks for (None, None).
        print(f"Error: {str(e)}")
        return None, None
def create_histogram(ax, data, year, title, y_max_limit, top_countries, is_interpolated=False):
    """
    Draw one frame of the surplus bar chart onto *ax*.

    Args:
        ax: Matplotlib Axes to draw on. It is cleared first.
        data: DataFrame with 'Country' and 'Surplus' columns. It also needs a
            'Year' column when ``is_interpolated`` is False.
        year: Year to select from *data* when ``is_interpolated`` is False.
            Ignored otherwise.
        title: Currently unused; the caller sets the figure title itself.
            Kept for interface compatibility.
        y_max_limit: Upper y-axis limit. The lower limit is fixed at 0.
        top_countries: Countries to show, in bar order. Countries without a
            positive surplus are drawn as zero-height bars labelled "0".
        is_interpolated: If True, *data* already holds this frame's values.
    """
    ax.clear()
    ax.set_facecolor(BACKGROUND_COLOR)
    ax.set_ylabel('Surplus')

    year_data = data if is_interpolated else data[data['Year'] == year]
    if year_data.empty:
        # Axes-relative coordinates keep the message centred regardless of data limits.
        ax.text(0.5, 0.5, "No Data Available", ha='center', va='center',
                color='red', fontsize=12, transform=ax.transAxes)
        return

    # One value per country, keeping only positive surpluses, then re-expand to
    # the fixed country order so bars do not jump between frames.
    grouped = year_data.groupby('Country', as_index=False)['Surplus'].sum()
    grouped = grouped[grouped['Surplus'] > 0]
    grouped = grouped.set_index('Country').reindex(top_countries, fill_value=0).reset_index()

    ax.set_ylim(0, y_max_limit)  # Only show positive surplus
    colors = sns.color_palette("viridis", len(grouped))
    ax.bar(grouped['Country'], grouped['Surplus'], color=colors, edgecolor='black', alpha=1.0)
    # After reset_index the row labels are 0..n-1, i.e. the categorical bar positions.
    for i, row in grouped.iterrows():
        ax.text(i, row['Surplus'] + 0.05, f"{row['Surplus']:,.0f}",
                ha='center', va='bottom', fontsize=9)
    # Rotate this axes' tick labels; plt.xticks would act on whatever axes
    # pyplot currently considers "current", not necessarily *ax*.
    plt.setp(ax.get_xticklabels(), rotation=30, ha="right")
    ax.grid(True, axis='y', linestyle='--', alpha=0.6)
def generate_gif(output_gif):
    """
    Render the animated top-10 trade-surplus bar chart, show it inline and save it as a GIF.

    Loads the data with ``load_and_prepare_data()``. Each country's surplus is
    linearly interpolated between consecutive years, using TRANSITION_STEPS
    frames per year. The animation is shown as HTML in the notebook and written
    to *output_gif* with the Pillow writer at FPS frames per second.

    Args:
        output_gif: Destination path of the GIF. Missing parent directories
            are created.

    Returns:
        None. Loading and saving problems are reported with ``print`` rather
        than raised.
    """
    data, top_countries = load_and_prepare_data()
    if data is None or top_countries is None:
        print("Failed to load data. Exiting.")
        return
    # The final configured year (YEAR_END) is deliberately left out of the animation.
    years = sorted(year for year in data['Year'].unique() if year < YEAR_END)
    print("Years available in the dataset:", years)
    if not years:
        print("No years found in the dataset. Exiting.")
        return

    # Year x Country surplus matrix, built once instead of re-filtering the long
    # DataFrame for every country on every frame. A country with no row for a
    # year counts as 0, as the previous per-frame lookup did.
    surplus_by_year = (
        data.pivot_table(index='Year', columns='Country', values='Surplus', aggfunc='sum')
        .reindex(index=years, columns=top_countries)
        .fillna(0)
    )

    # Scale the y-axis to the largest surplus actually shown, plus 10% headroom.
    # Fall back to 1 when nothing positive is shown so set_ylim(0, limit) stays valid.
    max_surplus = surplus_by_year.to_numpy().max() if surplus_by_year.size else 0
    y_max_limit = max_surplus * 1.1 if max_surplus > 0 else 1

    # One keyframe per year plus TRANSITION_STEPS - 1 in-between frames per gap.
    total_frames = (len(years) - 1) * TRANSITION_STEPS + 1
    fig, ax = plt.subplots(figsize=(10, 7))

    def update(frame):
        index, step = divmod(frame, TRANSITION_STEPS)
        if index >= len(years) - 1:
            # Last frame: the final year's values, no interpolation.
            displayed_year = years[-1]
            surplus = surplus_by_year.loc[years[-1]]
        else:
            fraction = step / TRANSITION_STEPS
            year_current, year_next = years[index], years[index + 1]
            displayed_year = year_current + fraction
            # Linear interpolation between the two yearly values.
            surplus = ((1 - fraction) * surplus_by_year.loc[year_current]
                       + fraction * surplus_by_year.loc[year_next])
        # Columns were reindexed to top_countries, so values align with that order.
        frame_data = pd.DataFrame({'Country': top_countries, 'Surplus': surplus.to_numpy()})
        fig.suptitle(f"China's Top 10 Trade Surplus Countries(In Millions) - {int(displayed_year)}",
                     fontsize=18, fontweight="bold", y=0.95)
        create_histogram(ax, frame_data, displayed_year, "Trade Surplus", y_max_limit, top_countries,
                         is_interpolated=True)

    print("Generating smooth animation for surplus...")
    try:
        # interval (ms) governs the inline HTML preview; the GIF uses FPS below.
        ani = FuncAnimation(fig, update, frames=total_frames, interval=100, blit=False)
        # Show the animation in Jupyter Notebook
        display(HTML(ani.to_jshtml()))
        output_dir = os.path.dirname(output_gif)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        try:
            ani.save(output_gif, writer='pillow', fps=FPS, dpi=150)
            print(f"Success! Saved as {output_gif}")
        except Exception as e:
            print(f"GIF Save Error: {str(e)}")
    finally:
        # Close the figure on every path so it is not leaked or rendered twice.
        plt.close(fig)
# Entry point: build the animation and write it to the configured GIF path.
generate_gif(output_gif=OUTPUT_SURPLUS_GIF)
Loading and preparing data... Warning: 'TradeType' column not found, checking 'Unnamed: 0'. Top 10 Countries by Trade Surplus: ['United States', 'Netherlands', 'India', 'United Kingdom', 'Vietnam', 'Mexico', 'Singapore', 'United Arab Emirates', 'Spain', 'Turkey'] Data successfully processed Years available in the dataset: [2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023] Generating smooth animation for surplus...