// DuckDB client used by the queries in this notebook. It exposes four tables:
//   summary, stats             - loaded from Parquet files under results/
//   basics, algorithm_basics   - loaded from CSV files, with the first line
//                                treated as a header row (header: true)
db = DuckDBClient.of({
summary: FileAttachment("results/summary.parquet"),
stats: FileAttachment("results/stats.parquet"),
basics: {file: FileAttachment("dataset_basics.csv"), header: true},
algorithm_basics: {file: FileAttachment("algorithms_basics.csv"), header: true}
});
// Distinct values of k present in the summary table. They are sorted
// explicitly because `select distinct` gives no ordering guarantee, and these
// rows feed the "value of k" selector, whose option order should be stable.
k_values = db.sql`select distinct k from summary order by k;`
viewof recall_threshold = Inputs.range([0,1], {step: 0.01, value: 0.9, label: "minimum recall"});
viewof k_value = Inputs.select(k_values.map(d => d.k), {value: 10, label: "value of k"});
// Best configuration of every algorithm on every selected dataset, for the
// currently selected recall threshold and k. One row per (algorithm, dataset).
//
// Pipeline (one CTE per step):
//   normalized_names      - summary rows with the first "-a2-" / "-e2-" in the
//                           dataset name replaced by "-".
//   selected_datasets     - datasets from `basics` whose type is
//                           'in-distribution' or 'out-of-distribution'.
//   filtered_summary      - summary joined with the selected datasets and with
//                           algorithm_basics, keeping rows where `not is_gpu`.
//                           NOTE(review): the join is a LEFT join, but
//                           `not is_gpu` is NULL for algorithms without a row
//                           in algorithm_basics, so those rows are dropped
//                           anyway - confirm whether that is intended.
//   expected_combinations - every dataset x algorithm pair, so that pairs with
//                           no qualifying configuration still yield a row.
//   ranked                - configurations with recall >= threshold at the
//                           chosen k, numbered by descending qps within each
//                           (algorithm, dataset).
//   scaled                - qps divided by the largest qps in the same dataset
//                           partition; missing pairs get qps 0 via ifnull
//                           (the ratio is NULL if the partition has no qps).
//   algorithms_ranks      - algorithms numbered (from 1) by descending mean of
//                           their best scaled qps over the datasets.
//   dataset_ranks         - datasets numbered by type, then by whether the name
//                           contains '-ip', then by descending average rc100.
//
// Output columns fx / fy place each algorithm in a grid that is 6 wide:
// fx = floor(algorithm_rank / 6), fy = algorithm_rank % 6. The floor is
// explicit and the result is cast to integer because `/` on integers is a
// floating-point division in current DuckDB releases (older releases truncated);
// without it fx would be fractional and every algorithm would get its own row.
fastdata = db.sql`
with
normalized_names as (
  select
    k,
    regexp_replace(dataset, '-(a2|e2)-', '-') as dataset,
    algorithm,
    params,
    avg_time,
    qps,
    recall
  from summary
),
selected_datasets as (
  select dataset, type
  from basics
  where type in ('in-distribution', 'out-of-distribution')
),
filtered_summary as (
  select *
  from normalized_names
  natural join selected_datasets
  natural left join algorithm_basics
  where not is_gpu
),
all_datasets as (
  select distinct dataset, type from filtered_summary
),
all_algorithms as (
  select distinct algorithm from filtered_summary
),
expected_combinations as (
  select *
  from all_datasets cross join all_algorithms
),
ranked as (
  select
    *,
    row_number() over (partition by algorithm, dataset order by qps desc) as rank
  from filtered_summary
  where recall >= ${recall_threshold} and k = ${k_value}
),
scaled as (
  select
    *,
    ifnull(qps, 0) / max(qps) over (partition by dataset, type) as scaled_qps
  from ranked
  natural right join expected_combinations
),
best_performing as (
  select algorithm, dataset, max(scaled_qps) as best_performance
  from scaled
  group by all
),
algorithms as (
  select algorithm, avg(best_performance) as mean_perf
  from best_performing
  group by all
),
algorithms_ranks as (
  select
    algorithm,
    mean_perf,
    row_number() over (order by mean_perf desc) as algorithm_rank
  from algorithms
),
datasets as (
  select dataset, type, avg(rc100) as difficulty
  from stats natural join selected_datasets
  group by all
),
dataset_ranks as (
  select
    dataset,
    difficulty,
    row_number() over (order by type, contains(dataset, '-ip'), difficulty desc) as dataset_rank
  from datasets
)
select
  algorithm,
  regexp_replace(dataset, '-[0-9]+-(cosine|normalized|euclidean|ip)', '') as dataset,
  type as dataset_type,
  dataset_rank,
  k,
  scaled_qps,
  qps,
  params,
  cast(floor(algorithm_rank / 6) as integer) as fx,
  (algorithm_rank % 6) as fy
from scaled natural join algorithms_ranks natural join dataset_ranks
where rank = 1 or rank is null
order by dataset_rank, algorithm, qps;
`
// Distinct algorithm names occurring in fastdata, collected into an array.
facet_keys = [...d3.union(fastdata.map(({algorithm}) => algorithm))];
dataset_ranks = fastdata.reduce((map, d) => {
map[d.dataset] = d.dataset_rank;
return map;
}, {});
dataset_types = fastdata.reduce((map, d) => {
map[d.dataset] = d.dataset_type;
return map;
}, {});
// Dataset names ordered by ascending dataset_rank; this is the order in which
// the datasets are laid out by the longitude scale.
longitude_domain = Object.keys(dataset_ranks).sort(
  (left, right) => dataset_ranks[left] - dataset_ranks[right]
);
// Scales
// Point scale assigning each dataset name (in longitude_domain order) a
// longitude between 180 and -180, with outer padding 0.5 and alignment 1.
longitude = d3.scalePoint()
  .domain(longitude_domain)
  .range([180, -180])
  .padding(0.5)
  .align(1);
// Formatter for queries-per-second values: SI-prefix notation with two
// significant digits (d3 format specifier ".2s").
fmt_qps = d3.format(".2s");
// Grid of radar-style charts, one facet per algorithm and one spoke per
// dataset. The chart is drawn on an azimuthal-equidistant projection rotated
// so the pole is at the center: a point's longitude selects the spoke (via
// the `longitude` scale) and its latitude `90 - value` puts `value` at that
// many degrees from the center, so scaled_qps = 1.0 sits on the outer disc.
Plot.plot({
  width: width,
  height: 3 * width / 4,
  marginTop: 20,
  marginBottom: 20,
  marginLeft: 20,
  marginRight: 20,
  projection: {
    type: "azimuthal-equidistant",
    rotate: [0, -90],
    // Note: 1.22° corresponds to max. percentage (1.0), plus some room for the labels
    domain: d3.geoCircle().center([0, 90]).radius(1.22)()
  },
  facet: {
    data: fastdata,
    // Grid position of each algorithm, computed from algorithm_rank in the
    // fastdata query: fx is the grid row, fy the grid column.
    y: (d) => d.fx,
    x: (d) => d.fy,
    axis: null
  },
  fx: {padding: .1},
  fy: {padding: .1},
  marks: [
    // Facet name
    Plot.text(fastdata,
      Plot.selectFirst({
        text: "algorithm",
        frameAnchor: "top-left",
        fontWeight: "700",
        fontSize: 24
      })
    ),
    // grey discs (one per 20% step of scaled qps)
    Plot.geo([1.0, 0.8, 0.6, 0.4, 0.2], {
      geometry: (r) => d3.geoCircle().center([0, 90]).radius(r)(),
      stroke: "black",
      fill: "black",
      strokeOpacity: 0.2,
      fillOpacity: 0.03,
      strokeWidth: 0.5
    }),
    // colored axes: one spoke per dataset, colored by dataset type
    Plot.link(longitude.domain(), {
      x1: longitude,
      y1: 90 - 1,
      x2: 0,
      y2: 90,
      stroke: (d) => (dataset_types[d] === "in-distribution") ? "#1f77b4" : "#ff7f0e",
      strokeOpacity: 0.5,
      strokeWidth: 2.5
    }),
    // tick labels, drawn only in facet (0, 0)
    // NOTE(review): facet (0, 0) presumably holds no algorithm because
    // algorithm_rank starts at 1, so it serves as a legend - confirm.
    Plot.text([0.4, 0.6, 0.8], {
      fx: 0, fy: 0,
      x: 180,
      y: (d) => 90 - d,
      dx: 2,
      textAnchor: "start",
      text: (d) => `${100 * d}%`,
      fill: "currentColor",
      stroke: "white",
      fontSize: 18
    }),
    // axes labels: full dataset names, drawn only in facet (0, 0)
    Plot.text(longitude.domain(), {
      fx: 0, fy: 0,
      x: longitude,
      y: 90 - 1.07,
      text: Plot.identity,
      lineWidth: 10,
      fontSize: 18
    }),
    // axes labels, initials: first two characters of the dataset name,
    // drawn in every facet except (0, 0) (facet: "exclude")
    Plot.text(longitude.domain(), {
      fx: 0, fy: 0, facet: "exclude",
      x: longitude,
      y: 90 - 1.09,
      text: (d) => d.slice(0, 2),
      lineWidth: 5,
      fontSize: 18
    }),
    // areas
    Plot.area(fastdata, {
      x1: ({ dataset }) => longitude(dataset),
      y1: ({ scaled_qps }) => 90 - scaled_qps,
      x2: 0,
      y2: 90,
      fill: "gray",
      fillOpacity: 0.25,
      stroke: "gray",
      curve: "cardinal-closed"
    }),
    // data values
    Plot.dot(fastdata, {
      x: (d) => longitude(d.dataset),
      y: (d) => 90 - d.scaled_qps,
      fill: (d) => (d.dataset_type === "in-distribution") ? "#1f77b4" : "#ff7f0e",
      stroke: "white"
    }),
    // interactive labels: qps, percentage of the best qps, and parameters of
    // the point nearest to the pointer
    Plot.text(
      fastdata,
      Plot.pointer({
        x: ({ dataset }) => longitude(dataset),
        y: ({ scaled_qps }) => 90 - scaled_qps,
        text: (d) => `${fmt_qps(d.qps)} qps\n(${Math.round(100 * d.scaled_qps)}%)\n${d.params}`,
        textAnchor: "start",
        dx: 4,
        fill: "currentColor",
        stroke: "white",
        maxRadius: 10,
        fontSize: 18
      })
    )
  ]
})