Fix GPU interpolation cache lifetime leaks

This commit is contained in:
2026-04-10 10:29:04 +08:00
parent e1a0bff43c
commit c578a15ecd
6 changed files with 241 additions and 99 deletions

View File

@@ -79,6 +79,15 @@ struct CachedInterpPlan
CachedInterpPlan() : nblocks(0) {}
};
struct CachedInterpPlanEntry
{
bool valid;
InterpPlanKey key;
CachedInterpPlan plan;
CachedInterpPlanEntry() : valid(false) {}
};
struct InterpBlockView
{
Block *bp;
@@ -268,6 +277,23 @@ bool should_try_cuda_interp(int ordn, int num_points, int num_var)
return num_points * num_var >= 256;
}
// Accessor for the one-and-only interpolation plan cache slot.
// The slot is a function-local static, so it is constructed on the first call
// and the same object is returned on every subsequent call.
CachedInterpPlanEntry &interp_plan_cache_entry()
{
    static CachedInterpPlanEntry single_slot;
    return single_slot;
}
// Returns true when every field of the two plan keys compares equal with ==.
// Fields are checked in the order patch, x, y, z, NN, Symmetry, myrank and the
// comparison stops at the first mismatch.
// NOTE(review): `patch` and `x` are stored as pointers by the caller, so this
// is an identity check on those, not a comparison of what they point to;
// `y`/`z` are presumably the same kind of value - confirm against InterpPlanKey.
bool same_interp_plan_key(const InterpPlanKey &lhs, const InterpPlanKey &rhs)
{
    // Where the points live: owning patch plus the three coordinate fields.
    const bool same_points = lhs.patch == rhs.patch &&
                             lhs.x == rhs.x &&
                             lhs.y == rhs.y &&
                             lhs.z == rhs.z;
    if (!same_points)
        return false;

    // Remaining scalar parameters of the request.
    return lhs.NN == rhs.NN &&
           lhs.Symmetry == rhs.Symmetry &&
           lhs.myrank == rhs.myrank;
}
CachedInterpPlan &get_cached_interp_plan(Patch *patch,
int NN, double **XX,
int Symmetry, int myrank,
@@ -276,8 +302,6 @@ CachedInterpPlan &get_cached_interp_plan(Patch *patch,
bool report_bounds_here,
bool allow_missing_points)
{
static map<InterpPlanKey, CachedInterpPlan, InterpPlanKeyLess> cache;
InterpPlanKey key;
key.patch = patch;
key.x = XX[0];
@@ -287,12 +311,16 @@ CachedInterpPlan &get_cached_interp_plan(Patch *patch,
key.Symmetry = Symmetry;
key.myrank = myrank;
map<InterpPlanKey, CachedInterpPlan, InterpPlanKeyLess>::iterator it = cache.find(key);
if (it != cache.end() && it->second.nblocks == static_cast<int>(block_index.views.size()))
return it->second;
CachedInterpPlanEntry &cache = interp_plan_cache_entry();
if (cache.valid &&
same_interp_plan_key(cache.key, key) &&
cache.plan.nblocks == static_cast<int>(block_index.views.size()))
return cache.plan;
CachedInterpPlan &plan = cache[key];
plan = CachedInterpPlan();
cache.valid = true;
cache.key = key;
cache.plan = CachedInterpPlan();
CachedInterpPlan &plan = cache.plan;
plan.nblocks = static_cast<int>(block_index.views.size());
plan.owner_rank.assign(NN, -1);
plan.owner_block.assign(NN, -1);
@@ -380,6 +408,13 @@ CachedInterpPlan &get_cached_interp_plan(Patch *patch,
return plan;
}
void release_interp_plan_cache_internal()
{
CachedInterpPlanEntry &cache = interp_plan_cache_entry();
cache.valid = false;
cache.plan = CachedInterpPlan();
}
bool run_cuda_interp_for_block(Block *BP,
const vector<InterpVarDesc> &vars,
const vector<int> &point_ids,
@@ -487,9 +522,14 @@ void interpolate_owned_points(MyList<var> *VarList,
}
}
} // namespace
Patch::Patch(int DIM, int *shapei, double *bboxi, int levi, bool buflog, int Symmetry) : lev(levi)
{
// Entry point for dropping the cached interpolation plan.
// Simply forwards to release_interp_plan_cache_internal().
void patch_release_interp_plan_cache()
{
release_interp_plan_cache_internal();
}
Patch::Patch(int DIM, int *shapei, double *bboxi, int levi, bool buflog, int Symmetry) : lev(levi)
{
int hbuffer_width = buffer_width;
if (lev == 0)