Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
327 changes: 292 additions & 35 deletions gpu4pyscf/pbc/dft/multigrid_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,19 @@
from pyscf.pbc.dft.multigrid import multigrid
from pyscf.pbc.lib.kpts import KPoints
from pyscf.pbc.df.df_jk import _format_kpts_band
from pyscf.pbc.gto.pseudo import pp_int
from pyscf.pbc.lib.kpts_helper import is_gamma_point
from pyscf.gto.mole import ATOM_OF, ANG_OF, NPRIM_OF, NCTR_OF, PTR_EXP, PTR_COEFF
from pyscf.pbc.dft import gen_grid as pbc_gen_grid_cpu
from pyscf.pbc import tools as pbc_tools_cpu
from gpu4pyscf.pbc.gto.pseudo.pp_int import get_pp_nl_gpu
from pyscf.pbc.lib.kpts_helper import is_gamma_point
from gpu4pyscf.lib import logger, utils
from gpu4pyscf.dft import numint
from gpu4pyscf.pbc.df.fft_jk import _format_dms, _format_jks
from gpu4pyscf.pbc.gto.cell import get_Gv
from gpu4pyscf.pbc.tools import pbc as pbc_tools
import gpu4pyscf.pbc.dft.multigrid as multigrid_v1
from gpu4pyscf.lib.cupy_helper import contract, tag_array, load_library, transpose_sum
from gpu4pyscf.lib.cupy_helper import contract, tag_array, load_library, get_avail_mem


__all__ = ['MultiGridNumInt']

Expand Down Expand Up @@ -326,6 +329,8 @@ def assign_pairs_to_blocks(
raise RuntimeError('count_pairs_on_blocks failed')

n_contributing_blocks = int(n_pairs_on_blocks[-1])
if n_contributing_blocks == 0:
return (None, None, None, None)
n_pairs_on_blocks = n_pairs_on_blocks[:-1]
sorted_block_index = cp.asarray(cp.argsort(-n_pairs_on_blocks), dtype=cp.int32)
accumulated_n_pairs_per_block = cp.zeros(n_blocks + 1, dtype=cp.int32)
Expand Down Expand Up @@ -362,7 +367,264 @@ def assign_pairs_to_blocks(
)


def sort_gaussian_pairs(mydf, xc_type="LDA"):
def multi_grids_tasks_lowmem(cell, fft_mesh=None, verbose=None, gamma_point=False, unrestricted=False):
assert multigrid.TASKS_TYPE == 'ke_cut', "rcut scheme not supported yet"
return multi_grids_tasks_for_ke_cut_lowmem(cell, fft_mesh, verbose, gamma_point, unrestricted)


def multi_grids_tasks_for_ke_cut_lowmem(cell, fft_mesh=None, verbose=None, gamma_point=False, unrestricted=False):
"""
Modified from pyscf.pbc.dft.multigrid.multigrid.multi_grids_tasks_for_ke_cut()
This function includes logic to split dense shells if the resulting fock matrix requires too much GPU memory.
"""
log = logger.new_logger(cell, verbose)
if fft_mesh is None:
fft_mesh = cell.mesh

# Split shells based on rcut
rcuts_pgto, kecuts_pgto = multigrid._primitive_gto_cutoff(cell)
ao_loc = cell.ao_loc_nr()

# cell that needs dense integration grids
def make_cell_dense_exp(shls_dense, ke0, ke1):
cell_dense = cell.copy(deep=False)
cell_dense._bas = cell._bas.copy()
cell_dense._env = cell._env.copy()

rcut_atom = [0] * cell.natm
ke_cutoff = 0
for ib in shls_dense:
ke = kecuts_pgto[ib]
idx = np.where((ke0 < ke) & (ke <= ke1))[0]
nprim1 = len(idx)
cs = cell._libcint_ctr_coeff(ib)
nprim, nc = cs.shape
if nprim1 < nprim: # no pGTO splitting within the shell
pexp = cell._bas[ib,PTR_EXP]
pcoeff = cell._bas[ib,PTR_COEFF]
cs1 = cs[idx]
cell_dense._env[pcoeff:pcoeff+cs1.size] = cs1.T.ravel()
cell_dense._env[pexp:pexp+nprim1] = cell.bas_exp(ib)[idx]
cell_dense._bas[ib,NPRIM_OF] = nprim1

ke_cutoff = max(ke_cutoff, ke[idx].max())

ia = cell.bas_atom(ib)
rcut_atom[ia] = max(rcut_atom[ia], rcuts_pgto[ib][idx].max())
cell_dense._bas = cell_dense._bas[shls_dense]
ao_idx = np.hstack([np.arange(ao_loc[i], ao_loc[i+1])
for i in shls_dense])
cell_dense.rcut = max(rcut_atom)
return cell_dense, ao_idx, ke_cutoff, rcut_atom

# cell that needs sparse integration grids
def make_cell_sparse_exp(shls_sparse, ke0):
cell_sparse = cell.copy(deep=False)
cell_sparse._bas = cell._bas.copy()
cell_sparse._env = cell._env.copy()

for ib in shls_sparse:
idx = np.where(kecuts_pgto[ib] <= ke0)[0]
nprim1 = len(idx)
cs = cell._libcint_ctr_coeff(ib)
nprim, nc = cs.shape
if nprim1 < nprim: # no pGTO splitting within the shell
pexp = cell._bas[ib,PTR_EXP]
pcoeff = cell._bas[ib,PTR_COEFF]
cs1 = cs[idx]
cell_sparse._env[pcoeff:pcoeff+cs1.size] = cs1.T.ravel()
cell_sparse._env[pexp:pexp+nprim1] = cell.bas_exp(ib)[idx]
cell_sparse._bas[ib,NPRIM_OF] = nprim1
cell_sparse._bas = cell_sparse._bas[shls_sparse]
ao_idx = np.hstack([np.arange(ao_loc[i], ao_loc[i+1])
for i in shls_sparse])
return cell_sparse, ao_idx

def get_nao_of_extracted_cell(shls_dense, original_cell, kecuts_pgto, ke0 = 0, ke1 = np.inf):
nao_sph_nctr = 0 # Actual number of orbitals
nao_cart_nprim = 0 # Number of primitives used for kernel
for i_bas in shls_dense:
bas = original_cell._bas[i_bas]
ang = bas[ANG_OF]
per_nao_sph = (2 * ang + 1)
per_nao_cart = (ang + 2) * (ang + 1) // 2
nctr = bas[NCTR_OF]

# nprim = bas[NPRIM_OF]
ke = kecuts_pgto[i_bas]
idx = np.where((ke0 < ke) & (ke <= ke1))[0]
nprim = len(idx)

nao_sph_nctr += per_nao_sph * nctr
nao_cart_nprim += per_nao_cart * nprim
return nao_sph_nctr, nao_cart_nprim

# Compute the max possible n_difference_images for memory partition
if gamma_point:
n_difference_images = 0
else:
max_neighboring_images = cp.asarray(gto.eval_gto.get_lattice_Ls(cell))
fake_kpts = np.array([[0.5,0.5,0.5]])
img_phase = image_phase_for_kpts(cell, max_neighboring_images, fake_kpts)
phase_diff_among_images, image_pair_difference_index = img_phase
n_difference_images = int(phase_diff_among_images.shape[1])
n_channel = 2 if unrestricted else 1

a = cell.lattice_vectors()
if abs(a-np.diag(a.diagonal())).max() < 1e-12:
init_mesh = multigrid.INIT_MESH_ORTH
else:
init_mesh = multigrid.INIT_MESH_NONORTH
ke_cutoff_min = pbc_tools_cpu.mesh_to_cutoff(cell.lattice_vectors(), init_mesh)
ke_cutoff_max = max([ke.max() for ke in kecuts_pgto])
ke1 = ke_cutoff_min.min()
ke_delimeter = [0, ke1]
while ke1 < ke_cutoff_max:
ke1 *= multigrid.KE_RATIO
ke_delimeter.append(ke1)

tasks = []
for ke0, ke1 in zip(ke_delimeter[:-1], ke_delimeter[1:]):
# shells which have high exps (small rcut)
shls_dense = [ib for ib, ke in enumerate(kecuts_pgto)
if np.any((ke0 < ke) & (ke <= ke1))]
if len(shls_dense) == 0:
continue

mesh = pbc_tools_cpu.cutoff_to_mesh(a, ke1)
if multigrid.TO_EVEN_GRIDS:
mesh = int((mesh+1)//2) * 2 # to the nearest even number

ke1_capped = ke1
if np.all(mesh >= fft_mesh):
# Including all rest shells
shls_dense = [ib for ib, ke in enumerate(kecuts_pgto)
if np.any(ke0 < ke)]
ke1_capped = ke_cutoff_max+1

dense_nao, dense_nprim_cart = get_nao_of_extracted_cell(shls_dense, cell, kecuts_pgto, ke0, ke1_capped)
dense_nao, dense_nprim_cart = int(dense_nao), int(dense_nprim_cart)

# shells which have low exps (big rcut)
shls_sparse = [ib for ib, ke in enumerate(kecuts_pgto)
if np.any(ke <= ke0)]

if len(shls_sparse) == 0:
sparse_nao, sparse_nprim_cart = 0, 0
else:
sparse_nao, sparse_nprim_cart = get_nao_of_extracted_cell(shls_sparse, cell, kecuts_pgto, 0, ke0)
sparse_nao, sparse_nprim_cart = int(sparse_nao), int(sparse_nprim_cart)

sum_nprim_cart = dense_nprim_cart + sparse_nprim_cart

fock_size = n_channel * n_difference_images * dense_nprim_cart * sum_nprim_cart
if gamma_point:
fock_nbytes_per_element = np.dtype(np.float64).itemsize
else:
# Why does it require a float64 and a complex128? Because when rotating the fock in image space to k space,
# it converts the float64 fock matrix into complex128, so at that point, we need to store both.
fock_nbytes_per_element = np.dtype(np.float64).itemsize + np.dtype(np.complex128).itemsize
fock_nbytes = fock_size * fock_nbytes_per_element

# At this stage almost no other memory is allocated,
# and this number can be much lower when the fock matrix is actually built.
available_gpu_memory = get_avail_mem()
available_gpu_memory = int(available_gpu_memory * 0.2)

n_split = (fock_nbytes + available_gpu_memory - 1) // available_gpu_memory
if n_split > 1:
log.warn(f"Warning: at dense shell ke range ({ke0}, {ke1_capped}], "
f"the fock matrix size ({fock_nbytes / 2**30} GiB) is too large, "
f"so the dense shells are split into {n_split} parts")

def split_list_evenly(lst, n_piece):
N = len(lst)
if n_piece >= N:
n_piece = N
q, r = divmod(N, n_piece)
out = []
offset = 0
for i in range(n_piece):
size = q + (1 if i < r else 0)
out.append(lst[offset:offset + size])
offset += size
return out

shls_dense_split = split_list_evenly(shls_dense, n_split)
shls_dense_cross = []

mesh = np.min([mesh, fft_mesh], axis=0)

if len(shls_sparse) == 0:
cell_sparse = None
ao_idx_sparse = []
else:
cell_sparse, ao_idx_sparse = make_cell_sparse_exp(shls_sparse, ke0)
cell_sparse.mesh = mesh

if cell_sparse is None:
grids_sparse = None
else:
grids_sparse = pbc_gen_grid_cpu.UniformGrids(cell_sparse)
grids_sparse.ao_idx = ao_idx_sparse

for shls_dense in shls_dense_split:
cell_dense, ao_idx_dense, _, _ = \
make_cell_dense_exp(shls_dense, ke0, ke1_capped)
cell_dense.mesh = mesh

grids_dense = pbc_gen_grid_cpu.UniformGrids(cell_dense)
grids_dense.ao_idx = ao_idx_dense

log.debug('mesh %s nao dense/sparse %d %d rcut %g',
mesh, len(ao_idx_dense), len(ao_idx_sparse), cell_dense.rcut)

if len(shls_dense_cross) > 0:
cell_dense_cross, ao_idx_dense_cross, _, _ = \
make_cell_dense_exp(shls_dense_cross, ke0, ke1_capped)
cell_dense_cross.mesh = mesh

if cell_sparse is None:
grids_lower_triangular = pbc_gen_grid_cpu.UniformGrids(cell_dense_cross)
grids_lower_triangular.ao_idx = ao_idx_dense_cross
else:
cell_lower_triangular = cell_sparse + cell_dense_cross
cell_lower_triangular._bas[cell_sparse.nbas:, ATOM_OF] -= len(cell_sparse._atm)

# Sort by atom first (later index has higher priority) to make aoslices work
bas_sort_by_atom_index = np.lexsort((cell_lower_triangular._bas[:,ANG_OF], cell_lower_triangular._bas[:,ATOM_OF]))

reverse_sort = np.argsort(bas_sort_by_atom_index)
ao_sort_by_atom_index = [[] for _ in range(cell_lower_triangular.nbas)]
ao_offset = 0
for i in range(cell_lower_triangular.nbas):
L = cell_lower_triangular._bas[i, ANG_OF]
nL = ((L+1)*(L+2)//2) if cell_lower_triangular.cart else (2*L+1)
nctr = cell_lower_triangular._bas[i, NCTR_OF]
ao_sort_by_atom_index[reverse_sort[i]] = np.arange(nL * nctr) + ao_offset
ao_offset += nL * nctr

ao_sort_by_atom_index = [int(item) for row in ao_sort_by_atom_index for item in row]
assert len(ao_sort_by_atom_index) == cell_lower_triangular.nao

cell_lower_triangular._bas = cell_lower_triangular._bas[bas_sort_by_atom_index]
ao_idx_lower_triangular = np.concatenate((ao_idx_sparse, ao_idx_dense_cross))
ao_idx_lower_triangular = ao_idx_lower_triangular[ao_sort_by_atom_index]
grids_lower_triangular = pbc_gen_grid_cpu.UniformGrids(cell_lower_triangular)
grids_lower_triangular.ao_idx = ao_idx_lower_triangular

tasks.append([grids_dense, grids_lower_triangular])
else:
tasks.append([grids_dense, grids_sparse])

shls_dense_cross.extend(shls_dense)

if np.all(mesh >= fft_mesh):
break
return tasks


def sort_gaussian_pairs(mydf, xc_type="LDA", gamma_point=False, unrestricted=False):
cell = mydf.cell
log = logger.new_logger(cell)
t0 = log.init_timer()
Expand All @@ -386,7 +648,7 @@ def sort_gaussian_pairs(mydf, xc_type="LDA"):

tasks = getattr(mydf, "tasks", None)
if tasks is None:
tasks = multigrid.multi_grids_tasks(cell, mydf.mesh, log)
tasks = multi_grids_tasks_lowmem(cell, mydf.mesh, log, gamma_point, unrestricted)
mydf.tasks = tasks

t0 = log.timer("task generation", *t0)
Expand Down Expand Up @@ -437,7 +699,7 @@ def sort_gaussian_pairs(mydf, xc_type="LDA"):

grouped_cell = equivalent_cell_in_localized + equivalent_cell_in_diffused

grouped_cell._bas[n_primitive_gtos_in_localized:, 0] -= len(
grouped_cell._bas[n_primitive_gtos_in_localized:, ATOM_OF] -= len(
subcell_in_localized_region._atm
)

Expand Down Expand Up @@ -540,6 +802,8 @@ def sort_gaussian_pairs(mydf, xc_type="LDA"):
env,
has_warned_instability
)
if gaussian_pair_indices is None:
continue
t1 = log.timer_debug2(
"assigning pairs to blocks in angular pair"
+ str((i_angular, j_angular)),
Expand Down Expand Up @@ -918,36 +1182,29 @@ def convert_xc_on_g_mesh_to_fock(
fock_slice = cp.einsum("nkpq,pi->nkiq", fock_slice, pairs["coeff_in_localized"])
fock_slice = cp.einsum("nkiq,qj->nkij", fock_slice, pairs["concatenated_coeff"])

# While mathematically it is correct to have concatenated
# ao indices in the addition, but it is possible that the ao
# indices overlap between localized gaussians and diffused gaussians
# (imagine two gaussians within a single shell, say, C2s).
# In this case, the addition to the same place requires atomic
# operation, while I guess in the cupy code it is assumed that
# the indices do not overlap, and hence no atomic guard.
# Anyway, the numerical result will be wrong if we use
# concatenated ao indices.
fock[
:,
:,
pairs["ao_indices_in_localized"][:, None],
pairs["ao_indices_in_localized"],
] += fock_slice[:, :, :, :n_ao_in_localized]
fock[
:,
:,
pairs["ao_indices_in_localized"][:, None],
pairs["ao_indices_in_diffused"],
] += fock_slice[:, :, :, n_ao_in_localized:]
def atomic_add_complex128(dst, idx, src):
if src.dtype == cp.float64:
assert dst.dtype == cp.float64
cp.add.at(dst, idx, src)
else:
assert dst.dtype == cp.complex128
assert src.dtype == cp.complex128
# Cupy doesn't allow atomic addition for complex128, so we need to add real and imag parts separately.
cp.add.at(dst.real, idx, src.real)
cp.add.at(dst.imag, idx, src.imag)

atomic_add_complex128(fock,
(slice(None), slice(None), pairs["ao_indices_in_localized"][:, None], pairs["ao_indices_in_localized"][None, :]),
fock_slice[:, :, :, :n_ao_in_localized])

atomic_add_complex128(fock,
(slice(None), slice(None), pairs["ao_indices_in_localized"][:, None], pairs["ao_indices_in_diffused"][None, :]),
fock_slice[:, :, :, n_ao_in_localized:])

if hermi == 1:
fock[
:,
:,
pairs["ao_indices_in_diffused"][:, None],
pairs["ao_indices_in_localized"],
] += (
fock_slice[:, :, :, n_ao_in_localized:].transpose(0, 1, 3, 2).conj()
)
atomic_add_complex128(fock,
(slice(None), slice(None), pairs["ao_indices_in_diffused"][:, None], pairs["ao_indices_in_localized"][None, :]),
fock_slice[:, :, :, n_ao_in_localized:].transpose(0, 1, 3, 2).conj())
else:
raise NotImplementedError

Expand Down
Loading
Loading