Identify unused columns in Snowflake and other data warehouses
Identify unused columns in your data warehouse to reduce cost and improve performance. We provide two approaches: a Snowflake query and a Python script.
One of the most underrated ways to reduce data warehouse costs and improve performance is to eliminate unused data. Beyond cost, this has the side benefit of simplifying your data and reducing errors and confusion. The SELECT team did a great write-up on how to identify unused tables in Snowflake, and I wanted to take it a step further by identifying unused columns. It turns out that the same view that records which tables a query used, snowflake.account_usage.access_history, also records the columns it used, so it can be used to identify unused columns too. The high-level flow is very similar to finding unused tables:
Use the access history to extract the columns used
Get a list of all the columns from the information schema
Take the difference between the two to identify the unused columns
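Stripped of the warehouse specifics, the three steps boil down to a set difference. A minimal sketch with made-up (database, schema, table, column) tuples standing in for the access history and the information schema:

```python
def find_unused_columns(all_columns, accessed_columns):
    """Return schema columns that never appear among the accessed columns, sorted."""
    return sorted(set(all_columns) - set(accessed_columns))


# Step 2 (hypothetical data): every (database, schema, table, column) in the schema
all_columns = [
    ("ANALYTICS", "PUBLIC", "ORDERS", "ID"),
    ("ANALYTICS", "PUBLIC", "ORDERS", "AMOUNT"),
    ("ANALYTICS", "PUBLIC", "ORDERS", "LEGACY_FLAG"),
    ("ANALYTICS", "PUBLIC", "USERS", "EMAIL"),
]
# Step 1 (hypothetical data): columns read by queries; repeats are expected
accessed_columns = [
    ("ANALYTICS", "PUBLIC", "ORDERS", "ID"),
    ("ANALYTICS", "PUBLIC", "ORDERS", "AMOUNT"),
    ("ANALYTICS", "PUBLIC", "ORDERS", "ID"),
    ("ANALYTICS", "PUBLIC", "USERS", "EMAIL"),
]

# Step 3: take the difference
unused = find_unused_columns(all_columns, accessed_columns)
print(unused)  # [('ANALYTICS', 'PUBLIC', 'ORDERS', 'LEGACY_FLAG')]
```

Using sets means a column read a thousand times counts the same as one read once; only "never read" matters here.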
We adapted the SELECT team's query to work with columns. It starts the same way, from the access_history view, but instead of stopping at the tables it does a second lateral flatten to extract each table's columns. We apply the same set of filters to limit this to columns that belong to tables owned by us, and then left join the full list of columns against the accessed ones, keeping only the columns with no match, to identify columns that haven't been accessed in the specified period (one month in this query).
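To see what the two lateral flattens do, here is a rough Python equivalent run against a single hand-written base_objects_accessed value. The field names (objectDomain, objectId, objectName, columns, columnId, columnName) are the ones the query reads; the IDs, names and the non-table object are made up for illustration.

```python
import json

# A made-up BASE_OBJECTS_ACCESSED value for one query; the shape is assumed
# from the fields the SQL reads.
base_objects_accessed = json.loads("""
[
  {"objectDomain": "Table", "objectId": 101, "objectName": "ANALYTICS.PUBLIC.ORDERS",
   "columns": [{"columnId": 1001, "columnName": "ID"},
               {"columnId": 1002, "columnName": "AMOUNT"}]},
  {"objectDomain": "Stage", "objectId": null, "objectName": "@MY_STAGE"}
]
""")


def flatten_columns(objects_accessed):
    """Yield (table_id, object_name, column_id, column_name), one row per accessed column."""
    for obj in objects_accessed:  # first flatten: one row per accessed object
        # Same filter as the WHERE clause: tables only, with a known object ID
        if obj.get("objectDomain") != "Table" or obj.get("objectId") is None:
            continue
        for col in obj.get("columns", []):  # second flatten: one row per column
            yield obj["objectId"], obj["objectName"], col["columnId"], col["columnName"]


rows = list(flatten_columns(base_objects_accessed))
for row in rows:
    print(row)
# (101, 'ANALYTICS.PUBLIC.ORDERS', 1001, 'ID')
# (101, 'ANALYTICS.PUBLIC.ORDERS', 1002, 'AMOUNT')
```

As with a non-OUTER lateral flatten, an object with no columns array simply produces no rows.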
```sql
with
access_history as (
    select *
    from snowflake.account_usage.access_history
    where query_start_time > current_date - interval '1 month'
),

-- One row per (query, table, column) read; base_objects_accessed resolves views to their underlying tables
access_history_flattened as (
    select
        access_history.query_id,
        access_history.query_start_time,
        access_history.user_name,
        objects_accessed.value:objectId::integer as table_id,
        objects_accessed.value:objectName::text as object_name,
        cols.value:columnName::text as column_name,
        cols.value:columnId::integer as column_id
    from access_history,
        lateral flatten(input => access_history.base_objects_accessed) as objects_accessed,
        lateral flatten(input => objects_accessed.value:columns) as cols
    where objects_accessed.value:objectDomain::text = 'Table'
        and objects_accessed.value:objectId::integer is not null
),

column_access_history as (
    select
        ahf.query_id,
        ahf.query_start_time,
        ahf.user_name,
        ahf.column_id,
        cols.column_name,
        cols.table_name,
        cols.table_schema,
        cols.table_catalog
    from access_history_flattened ahf
    join snowflake.account_usage.columns cols
        on ahf.column_id = cols.column_id
)

-- Anti-join: keep every existing column with no recorded access in the window
select
    all_cols.table_catalog,
    all_cols.table_schema,
    all_cols.table_name,
    all_cols.column_name
from snowflake.account_usage.columns all_cols
left join column_access_history cah
    on all_cols.column_id = cah.column_id
where all_cols.table_schema not in ('INFORMATION_SCHEMA')
    and all_cols.deleted is null  -- skip columns that have since been dropped
    and cah.column_id is null
group by all;
```
Unfortunately, the above only works on Snowflake Enterprise Edition or higher, since Standard Edition accounts do not have access to the access_history view. That doesn't sit well with us; we believe all Snowflake users should be able to see their unused columns. On top of that, why limit ourselves to Snowflake? As long as you can extract the columns from a query and have access to the schema, you can do this in any SQL dialect.
We wrote a simple script that does just that. It's more involved than the query above and needs to be tested on a wider range of queries, but it follows the same approach. It takes two CSV files: one of queries (with query_text, database_name and schema_name columns) and one of the information schema (catalog, schema, table and column name). The biggest difference is that, rather than having the columns available in a nicely structured JSON object, we have to extract them from the query text, which we do with the powerful sqlglot library. We also have to add the catalog and schema to each table when they're missing ("qualifying" it), and we use sqlglot's qualify function to do the same for the columns, which also expands * into an explicit column list. Finally, we use sqlglot's scopes to keep only the physical table columns rather than computed ones, such as columns produced by a CTE or subquery.

Note that this approach has a few limitations: as mentioned above, it needs to be tested on more non-SELECT query types, it will flag columns as unused if they are only used indirectly through a downstream view, and there are likely more. So tread carefully, but we hope it acts as a good starting point.
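Qualifying a table is mostly name resolution: a table referenced as orders in a query run with database ANALYTICS and schema PUBLIC really means ANALYTICS.PUBLIC.ORDERS. A simplified, stdlib-only sketch of that defaulting follows. qualify_table_name is a hypothetical helper, not part of sqlglot, and it assumes unquoted identifiers (which Snowflake upper-cases) that contain no dots of their own.

```python
def qualify_table_name(name, default_database, default_schema):
    """Return (database, schema, table) for a possibly partial table name.

    Hypothetical helper. Simplifications: identifiers are assumed unquoted
    (Snowflake upper-cases those) and to contain no dots of their own.
    """
    parts = [part.strip().upper() for part in name.split(".")]
    if len(parts) == 1:  # table -> use the query's database and schema
        parts = [default_database.upper(), default_schema.upper()] + parts
    elif len(parts) == 2:  # schema.table -> use the query's database
        parts = [default_database.upper()] + parts
    elif len(parts) != 3:
        raise ValueError(f"unexpected table name: {name!r}")
    return tuple(parts)


print(qualify_table_name("orders", "ANALYTICS", "PUBLIC"))
# ('ANALYTICS', 'PUBLIC', 'ORDERS')
print(qualify_table_name("staging.orders", "ANALYTICS", "PUBLIC"))
# ('ANALYTICS', 'STAGING', 'ORDERS')
```

In the script, sqlglot does this along with the harder column-level resolution (aliases, *, CTEs), with each query's database_name and schema_name supplying the defaults.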
```python
import argparse
import csv
from collections import Counter

from sqlglot import exp, parse_one
from sqlglot.errors import SqlglotError
from sqlglot.optimizer.qualify import qualify
from sqlglot.optimizer.scope import build_scope, find_all_in_scope


def read_queries(query_file):
    # Read queries from a CSV file and return one dict per row, keyed by the lowercased header
    with open(query_file, newline="") as f:
        return [{key.lower(): value for key, value in row.items()} for row in csv.DictReader(f)]


def read_info_schema(info_schema_file):
    # Read the info schema (catalog, schema, table name, column name) from a CSV file and
    # return it both as a nested dictionary and as a flat list of tuples
    # Nested format is: catalog -> schema -> table name -> column name -> type
    schema = {}
    flat_schema = []
    with open(info_schema_file, newline="") as f:
        csv_reader = csv.reader(f)
        next(csv_reader)  # Skip header
        for row in csv_reader:
            catalog, schema_name, table_name, column_name = map(str.upper, row)
            table = schema.setdefault(catalog, {}).setdefault(schema_name, {}).setdefault(table_name, {})
            table[column_name] = "DUMMY"  # Placeholder type: we only need the column names
            flat_schema.append((catalog, schema_name, table_name, column_name))
    return schema, flat_schema


def extract_columns(query_text, database_name, schema_name, schema):
    # Extract the columns from a query that map to actual columns in a table
    # Based on https://github.com/tobymao/sqlglot/blob/main/posts/ast_primer.md
    # Naming differs: sqlglot uses catalog -> db -> table, Snowflake uses database -> schema -> table,
    # so database_name (Snowflake) -> catalog (sqlglot) and schema_name (Snowflake) -> db (sqlglot)
    try:
        parsed = parse_one(query_text, dialect="snowflake")
        # Fill in missing catalog/db on every table *before* resolving columns against the schema,
        # qualify each column with its table, and expand * into explicit columns
        qualified = qualify(
            parsed,
            schema=schema,
            dialect="snowflake",
            catalog=database_name,
            db=schema_name,
            validate_qualify_columns=False,  # Leave unresolvable columns alone instead of raising
        )
    except SqlglotError:
        return []  # Skip queries sqlglot cannot parse or qualify
    root = build_scope(qualified)
    if root is None:
        return []  # Not a query with a scope (e.g. most DDL)
    columns = []
    # Visit every scope (CTEs, subqueries, unions), not just the outermost SELECT
    for scope in root.traverse():
        for column in find_all_in_scope(scope.expression, exp.Column):
            source = scope.sources.get(column.table)
            # Keep physical table columns; skip columns computed in a CTE or subquery
            if not isinstance(source, exp.Table):
                continue
            columns.append((source.catalog, source.db, source.name, column.name))
    return columns


def summarize_columns(columns):
    # Flatten the per-query column lists and count how often each column is referenced
    return Counter(col for query_columns in columns for col in query_columns)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Identify unused columns in SQL queries")
    parser.add_argument("--query_file", required=True, help="CSV of queries to analyze")
    parser.add_argument(
        "--info_schema_file",
        required=True,
        help="CSV file containing the information schema for the database",
    )
    args = parser.parse_args()

    queries = read_queries(args.query_file)
    print(f"Read {len(queries)} queries from {args.query_file}")
    info_schema, info_schema_flat = read_info_schema(args.info_schema_file)
    print(f"Read {len(info_schema_flat)} information schema rows from {args.info_schema_file}")

    cols = [
        extract_columns(
            query["query_text"],
            database_name=query["database_name"].upper(),
            schema_name=query["schema_name"].upper(),
            schema=info_schema,
        )
        for query in queries
    ]
    col_counts = summarize_columns(cols)

    # Print the 20 most common columns, one per line
    print("Most common columns (20):")
    for col, count in col_counts.most_common(20):
        print(f"{col}: {count}")

    # Columns in the info schema that no query references are unused
    unused_cols = sorted(set(info_schema_flat) - set(col_counts))
    print(f"Unused columns ({len(unused_cols)}):")
    for col in unused_cols:
        print(col)
```
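To try the script locally, here is a small sketch that writes toy versions of its two inputs. The column names of the query file (query_text, database_name, schema_name) are the ones the script looks up; in practice they could come from an export of Snowflake's QUERY_HISTORY view, which has QUERY_TEXT, DATABASE_NAME and SCHEMA_NAME columns. The file names, queries and tables below are hypothetical.

```python
import csv

# Hypothetical sample data: two queries and the information schema for two tables
QUERIES = [
    {"query_text": "select id, amount from orders", "database_name": "ANALYTICS", "schema_name": "PUBLIC"},
    {"query_text": "select email from users where id = 1", "database_name": "ANALYTICS", "schema_name": "PUBLIC"},
]
INFO_SCHEMA = [
    ("ANALYTICS", "PUBLIC", "ORDERS", "ID"),
    ("ANALYTICS", "PUBLIC", "ORDERS", "AMOUNT"),
    ("ANALYTICS", "PUBLIC", "ORDERS", "LEGACY_FLAG"),
    ("ANALYTICS", "PUBLIC", "USERS", "ID"),
    ("ANALYTICS", "PUBLIC", "USERS", "EMAIL"),
]


def write_sample_inputs(query_path, info_schema_path):
    """Write the toy query file and information schema file as CSV."""
    with open(query_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["query_text", "database_name", "schema_name"])
        writer.writeheader()
        writer.writerows(QUERIES)
    with open(info_schema_path, "w", newline="") as f:
        writer = csv.writer(f)
        # The script skips this header row and reads the four columns by position
        writer.writerow(["table_catalog", "table_schema", "table_name", "column_name"])
        writer.writerows(INFO_SCHEMA)
```

After calling write_sample_inputs("queries.csv", "info_schema.csv"), the script (saved under a hypothetical name such as unused_columns.py) can be run with --query_file queries.csv --info_schema_file info_schema.csv. With this toy data, ORDERS.LEGACY_FLAG is the only column no query references, so it is the one we would expect in the unused list.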