#!/usr/bin/env python3
# coding: utf-8
"""DockQ scoring and interface-metric helpers.
The public functions in this module compute RMSD, contact fractions, and the
combined DockQ score for protein-protein and protein-nucleic interfaces.
"""
import itertools
import logging
import math
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from ..alignment import align_seq as _align_seq
from ..core import align_seq_based, align_index_based, coor_align, get_common_atoms, rmsd as core_rmsd
from ..select import remove_incomplete_backbone_residues
__all__ = ["rmsd", "interface_rmsd", "native_contact", "dockQ", "dockQ_multimer"]
logger = logging.getLogger(__name__)
def _count_chain_combinations(clusters):
"""Return the number of valid chain-map permutations for *clusters*.
Mirrors ``DockQ.count_chain_combinations``. *clusters* maps each key
(native or model chain) to the list of compatible counterpart chains.
The formula counts the number of ways to assign one counterpart per key
without repeats, grouping identical candidate lists together:
.. math::
\\prod_{\\text{unique cluster}} P(|\\text{cluster}|, k)
where *k* is the number of keys whose candidate list equals that cluster.
"""
cluster_tuples = [tuple(sorted(v)) for v in clusters.values()]
n_combos = 1
for cluster, k in Counter(cluster_tuples).items():
n = len(cluster)
if n < k:
return 0 # impossible to assign without repeats
n_combos *= math.factorial(n) // math.factorial(n - k)
return int(n_combos)
def _sequence_identity(seq1, seq2):
"""Return sequence identity using Smith-Waterman alignment.
Aligns the two sequences with the C++ Smith-Waterman implementation and
returns the fraction of identical positions within the aligned region
(positions where neither sequence has a gap). Normalising by the aligned
length rather than the full sequence length means that N/C-terminal
extensions do not artificially reduce the score.
"""
s1 = seq1.replace("-", "").upper()
s2 = seq2.replace("-", "").upper()
if not s1 or not s2:
return 0.0
a1, a2, _ = _align_seq(s1, s2)
aligned = [(c1, c2) for c1, c2 in zip(a1, a2) if c1 != "-" and c2 != "-"]
if not aligned:
return 0.0
matches = sum(c1 == c2 for c1, c2 in aligned)
return matches / len(aligned)
def _count_clashes(coor, rec_chains, lig_chains, clash_cutoff=2.0):
"""Count (receptor, ligand) residue pairs with any atom within *clash_cutoff* A."""
rec_sel = " ".join(rec_chains)
lig_sel = " ".join(lig_chains)
clashes_list = []
for frame_index in range(len(coor.models)):
close = coor.select_atoms(
f"(chain {rec_sel} and within {clash_cutoff} of chain {lig_sel}) or "
f"(chain {lig_sel} and within {clash_cutoff} of chain {rec_sel})",
frame=frame_index,
)
rec_close = close.select_atoms(
f"chain {rec_sel} and within {clash_cutoff} of chain {lig_sel}"
)
count = 0
for residue in np.unique(rec_close.uniq_resid):
lig_close = close.select_atoms(
f"chain {lig_sel} and within {clash_cutoff} of residue {residue}"
)
count += len(np.unique(lig_close.uniq_resid))
clashes_list.append(count)
return clashes_list
def _check_interface_viable(coor, native_coor, rec_m, lig_m, rec_n, lig_n, back_atom, cutoff=10.0, _native_has_interface=None, _model_backbone_cache=None):
"""Cheap pre-check: return (True, "") when the interface is scoreable.
Avoids the expensive ``dockQ`` call (and its logged warnings) when it
would obviously fail because:
- a model chain has no protein backbone atoms, or
- the native structure has no interface residues between the two chains.
"""
back_list = " ".join(back_atom)
if _model_backbone_cache is not None:
rec_ok = _model_backbone_cache.get(rec_m)
lig_ok = _model_backbone_cache.get(lig_m)
else:
rec_ok = lig_ok = None
if rec_ok is None:
model_rec = coor.select_atoms(f"protein and chain {rec_m} and name {back_list}")
rec_ok = len(np.unique(model_rec.models[0].uniq_resid)) > 0
if not rec_ok:
return False, f"no backbone atoms in model chain {rec_m}"
if lig_ok is None:
model_lig = coor.select_atoms(f"protein and chain {lig_m} and name {back_list}")
lig_ok = len(np.unique(model_lig.models[0].uniq_resid)) > 0
if not lig_ok:
return False, f"no backbone atoms in model chain {lig_m}"
if _native_has_interface is None:
nat_interface = native_coor.select_atoms(
f"chain {rec_n} and within {cutoff} of chain {lig_n}"
)
_native_has_interface = len(np.unique(nat_interface.models[0].uniq_resid)) > 0
if not _native_has_interface:
return False, f"no interface residues between native chains {rec_n} and {lig_n}"
return True, ""
def _get_chain_types(coor, chains):
"""Return ``{chain: type}`` for each chain.
Types are ``'protein'``, ``'dna'``, ``'rna'``, or ``'other'``. Uses
``select_atoms`` with the built-in molecule-type keywords so it works
independently of ``get_aa_seq()`` (which only returns protein chains).
"""
types = {}
for ch in chains:
if len(np.unique(coor.select_atoms(f"protein and chain {ch}").models[0].uniq_resid)) > 0:
types[ch] = "protein"
elif len(np.unique(coor.select_atoms(f"dna and chain {ch}").models[0].uniq_resid)) > 0:
types[ch] = "dna"
elif len(np.unique(coor.select_atoms(f"rna and chain {ch}").models[0].uniq_resid)) > 0:
types[ch] = "rna"
else:
types[ch] = "other"
return types
[docs]
def rmsd(coor_1, coor_2, selection="name CA", index_list=None, frame_ref=0):
    """Compute RMSD between two sets of coordinates.

    Parameters
    ----------
    coor_1 : Coor
        First set of coordinates.
    coor_2 : Coor
        Second set of coordinates (the reference).
    selection : str, optional
        Selection string applied to both structures when *index_list* is
        not provided.
    index_list : list, optional
        Pair of index lists ``[index_1, index_2]``; takes precedence over
        *selection*.
    frame_ref : int, optional
        Reference frame index in *coor_2*.

    Returns
    -------
    list[float]
        RMSD values for each model in *coor_1*.

    Raises
    ------
    ValueError
        If *frame_ref* is outside ``[0, coor_2.model_num)``, if either index
        list is empty, or if the two index lists differ in length.
    """
    if not 0 <= frame_ref < coor_2.model_num:
        raise ValueError(
            "Reference frame index is larger than the number of frames in the reference structure"
        )

    if index_list is not None:
        index_1, index_2 = index_list[0], index_list[1]
    else:
        index_1 = coor_1.get_index_select(selection)
        index_2 = coor_2.get_index_select(selection)

    if not (len(index_1) and len(index_2)):
        raise ValueError("No atoms selected for RMSD calculation")
    if len(index_1) != len(index_2):
        raise ValueError("Index lists must have the same length")

    return core_rmsd(coor_1, coor_2, list(index_1), list(index_2), frame_ref=frame_ref)
[docs]
def interface_rmsd(
    coor,
    coor_native,
    rec_chains_native,
    lig_chains_native,
    cutoff=10.0,
    back_atom=None,
    index_pair=None,
):
    """Compute the interface RMSD between a model and a native structure.

    The interface is defined as atoms within ``cutoff`` Angstrom of the
    opposite partner chain(s) in the native structure.  The RMSD is computed
    on the selected backbone atoms after aligning the model to the native
    structure using the interface atoms.

    Parameters
    ----------
    coor : Coor
        Model coordinates.
    coor_native : Coor
        Native coordinates.
    rec_chains_native : list[str]
        Native receptor chains.
    lig_chains_native : list[str]
        Native ligand chains.
    cutoff : float, default=10.0
        Interface distance cutoff in Angstrom.
    back_atom : list[str], optional
        Backbone atom names used for RMSD (default: ``["CA", "N", "C", "O"]``).
    index_pair : tuple[list[int], list[int]], optional
        Pre-computed ``(model_indices, native_indices)`` of the interface
        backbone atoms.  When supplied the function skips the chain-selection
        and residue-mapping steps entirely and calls
        :func:`align_index_based` directly.  This handles index lists that
        cannot be recovered from residue ids (e.g. when model chains are
        numbered from 1 on every chain and thus have overlapping ``resid``
        values).  Lists and NumPy index arrays are both accepted.

    Returns
    -------
    list
        Interface RMSD values for each model.  Contains ``None`` entries
        (one per model) when no interface residues are found or when
        *index_pair* holds no model indices.

    Raises
    ------
    ValueError
        On the selection path, if model and native do not contain the same
        number of residues, or if the mapped interface selections differ in
        atom count.
    """
    if back_atom is None:
        back_atom = ["CA", "N", "C", "O"]

    if index_pair is not None:
        iface_model_idx, iface_native_idx = index_pair
        # Test the length, not the truth value: NumPy arrays have no
        # unambiguous truthiness, and rmsd() checks emptiness the same way.
        if iface_model_idx is None or len(iface_model_idx) == 0:
            return [None] * len(coor.models)
        rmsds, _, _ = align_index_based(
            coor, coor_native, iface_model_idx, iface_native_idx
        )
        return rmsds

    rec_sel = " ".join(rec_chains_native)
    lig_sel = " ".join(lig_chains_native)
    lig_interface = coor_native.select_atoms(
        f"chain {lig_sel} and within {cutoff} of chain {rec_sel}"
    )
    rec_interface = coor_native.select_atoms(
        f"chain {rec_sel} and within {cutoff} of chain {lig_sel}"
    )
    interface_residues = np.concatenate(
        (
            np.unique(lig_interface.models[0].uniq_resid),
            np.unique(rec_interface.models[0].uniq_resid),
        )
    )
    if len(interface_residues) == 0:
        logger.info("No interface residues found")
        return [None] * len(coor.models)

    # Pair residues by order of first appearance; this positional mapping is
    # only meaningful when both structures hold the same number of residues.
    native_unique = list(dict.fromkeys(coor_native.models[0].uniq_resid))
    coor_unique = list(dict.fromkeys(coor.models[0].uniq_resid))
    if len(native_unique) != len(coor_unique):
        raise ValueError("Interface selections do not contain the same residue count")
    native_to_coor = dict(zip(native_unique, coor_unique))
    mapped_interface = [native_to_coor[nat] for nat in interface_residues]

    native_list = " ".join(str(i) for i in interface_residues)
    coor_list = " ".join(str(i) for i in mapped_interface)
    back_list = " ".join(back_atom)

    index = coor.get_index_select(f"residue {coor_list} and name {back_list}")
    index_native = coor_native.get_index_select(
        f"residue {native_list} and name {back_list}"
    )
    if len(index) != len(index_native):
        raise ValueError(
            "The number of atoms in the interface is not the same in the two models"
        )

    rmsds, _, _ = align_index_based(coor, coor_native, index, index_native)
    return rmsds
[docs]
def dockQ(
    coor,
    native_coor,
    rec_chains=None,
    lig_chains=None,
    native_rec_chains=None,
    native_lig_chains=None,
    back_atom=None,
    _search_mode=False,
):
    """Compute DockQ scores between a model and a native structure.

    DockQ combines interface contacts (Fnat), ligand RMSD (LRMS), and
    interface RMSD (iRMS) into a single docking quality metric::

        DockQ = (Fnat + 1/(1 + (LRMS/8.5)^2) + 1/(1 + (iRMS/1.5)^2)) / 3

    Chain roles are inferred by selecting the shortest chain as ligand when
    not provided.

    Parameters
    ----------
    coor : Coor
        Model coordinates.
    native_coor : Coor
        Native coordinates.
    rec_chains : list[str], optional
        Model receptor chains.  If ``None``, uses all chains except the
        ligand chain(s).
    lig_chains : list[str], optional
        Model ligand chains.  If ``None``, uses the shortest chain.
    native_rec_chains : list[str], optional
        Native receptor chains.  If ``None``, uses all chains except the
        ligand chain(s).
    native_lig_chains : list[str], optional
        Native ligand chains.  If ``None``, uses the shortest chain.
    back_atom : list[str], optional
        Backbone atom names used for alignment and RMSD calculations
        (default: ``["CA", "N", "C", "O"]``).
    _search_mode : bool, optional
        Private fast path used by the chain-map search.  When true, the
        function returns right after the receptor alignment and ligand
        RMSD: ``DockQ`` then holds only the LRMS term divided by 3,
        ``Fnat``/``Fnonnat``/``clashes`` are zero and ``iRMS`` is ``None``.

    Returns
    -------
    dict
        Dictionary with keys ``Fnat``, ``Fnonnat``, ``rRMS``, ``iRMS``,
        ``LRMS``, ``DockQ`` and ``clashes``, each containing one value per
        model.

    Raises
    ------
    ValueError
        If the aligned model and native atoms cover a different number of
        unique residues.

    Notes
    -----
    This implementation mirrors the pdb_numpy DockQ pipeline and relies on
    sequence-based alignment for receptor superposition before computing the
    ligand RMSD and interface metrics.
    """

    def scale_rms(rms, d):
        """Map an RMSD onto (0, 1]; a missing RMSD (``None``) scores 0."""
        if rms is None:
            return 0.0
        return 1.0 / (1.0 + (rms / d) ** 2)

    def shortest_chain(seq_by_chain):
        """Return the chain id with the fewest non-gap residues."""
        return min(
            seq_by_chain.items(), key=lambda item: len(item[1].replace("-", ""))
        )[0]

    if back_atom is None:
        back_atom = ["CA", "N", "C", "O"]

    model_seq = coor.get_aa_seq()
    native_seq = native_coor.get_aa_seq()

    if lig_chains is None:
        lig_chains = [shortest_chain(model_seq)]
        logger.info("Model ligand chains : %s", " ".join(lig_chains))
    if rec_chains is None:
        rec_chains = [chain for chain in model_seq if chain not in lig_chains]
        logger.info("Model receptor chains : %s", " ".join(rec_chains))
    if native_lig_chains is None:
        native_lig_chains = [shortest_chain(native_seq)]
        logger.info("Native ligand chains : %s", " ".join(native_lig_chains))
    if native_rec_chains is None:
        native_rec_chains = [
            chain for chain in native_seq if chain not in native_lig_chains
        ]
        logger.info("Native receptor chains : %s", " ".join(native_rec_chains))

    # Keep protein heavy atoms of the requested chains, then drop residues
    # lacking any of the backbone atoms.
    clean_coor = coor.select_atoms(
        f"protein and not altloc B C D and not name H* and chain {' '.join(rec_chains + lig_chains)}"
    )
    clean_native_coor = native_coor.select_atoms(
        f"protein and not altloc B C D and not name H* and chain {' '.join(native_rec_chains + native_lig_chains)}"
    )
    clean_coor = remove_incomplete_backbone_residues(clean_coor, back_atom)
    clean_native_coor = remove_incomplete_backbone_residues(
        clean_native_coor, back_atom
    )

    # Superpose on the receptor; the ligand RMSD is measured afterwards.
    rmsd_prot_list, align_rec_index, align_rec_native_index = align_seq_based(
        clean_coor,
        clean_native_coor,
        chain_1=rec_chains,
        chain_2=native_rec_chains,
        back_names=back_atom,
    )
    logger.info("Receptor RMSD: %.3f A", rmsd_prot_list[0])
    logger.info(
        "Found %d residues in common (receptor)",
        len(align_rec_index) // len(back_atom),
    )

    lig_index, lig_native_index = get_common_atoms(
        clean_coor,
        clean_native_coor,
        chain_1=lig_chains,
        chain_2=native_lig_chains,
        back_names=back_atom,
        matrix_file="",
    )
    lrmsd_list = rmsd(
        clean_coor, clean_native_coor, index_list=[lig_index, lig_native_index]
    )
    logger.info("Ligand RMSD: %.3f A", lrmsd_list[0])
    logger.info(
        "Found %d residues in common (ligand)",
        len(lig_index) // len(back_atom),
    )

    if _search_mode:
        # Cheap ranking score: only the LRMS term, everything else neutral.
        n_models = len(lrmsd_list)
        return {
            "Fnat": [0.0] * n_models,
            "Fnonnat": [0.0] * n_models,
            "rRMS": rmsd_prot_list,
            "iRMS": [None] * n_models,
            "LRMS": lrmsd_list,
            "DockQ": [scale_rms(lrms, 8.5) / 3.0 for lrms in lrmsd_list],
            "clashes": [0] * n_models,
        }

    # Atom indices aligned between model and native (receptor, then ligand).
    all_model_idx = align_rec_index + lig_index
    all_native_idx = align_rec_native_index + lig_native_index

    native_all_uids = np.asarray(clean_native_coor.models[0].uniq_resid)
    coor_residue = np.asarray(clean_coor.models[0].uniq_resid)[
        np.asarray(all_model_idx, dtype=int)
    ]
    native_residue = native_all_uids[np.asarray(all_native_idx, dtype=int)]

    coor_residue_unique = np.unique(coor_residue)
    native_residue_unique = np.unique(native_residue)
    if len(coor_residue_unique) != len(native_residue_unique):
        raise ValueError("Residue alignment mismatch between model and native")

    # Restrict both structures to the residues that could be aligned.
    interface_coor = clean_coor.select_atoms(
        f"residue {' '.join(str(i) for i in coor_residue_unique)}"
    )
    interface_native_coor = clean_native_coor.select_atoms(
        f"residue {' '.join(str(i) for i in native_residue_unique)}"
    )

    # Give each aligned residue pair a shared id so that contacts can be
    # compared between the two structures.
    residue_id_map = {}
    native_residue_id_map = {}
    common_residue_ids = {}
    common_id = 1
    for model_residue_id, native_residue_id in zip(coor_residue, native_residue):
        if native_residue_id not in common_residue_ids:
            common_residue_ids[native_residue_id] = common_id
            common_id += 1
        mapped_id = common_residue_ids[native_residue_id]
        residue_id_map[model_residue_id] = mapped_id
        native_residue_id_map[native_residue_id] = mapped_id

    # Native interface residues: either partner within 10 A of the other.
    nat_lig_iface = clean_native_coor.select_atoms(
        f"chain {' '.join(native_lig_chains)} "
        f"and within 10.0 of chain {' '.join(native_rec_chains)}"
    )
    nat_rec_iface = clean_native_coor.select_atoms(
        f"chain {' '.join(native_rec_chains)} "
        f"and within 10.0 of chain {' '.join(native_lig_chains)}"
    )
    iface_native_uids = set(
        list(nat_lig_iface.models[0].uniq_resid)
        + list(nat_rec_iface.models[0].uniq_resid)
    )

    # Keep the aligned atom pairs whose native atom is at the interface.
    iface_model_idx = []
    iface_native_idx = []
    for mi, ni in zip(all_model_idx, all_native_idx):
        if native_all_uids[ni] in iface_native_uids:
            iface_model_idx.append(mi)
            iface_native_idx.append(ni)

    fnat_list, fnonnat_list = native_contact(
        interface_coor,
        interface_native_coor,
        rec_chains,
        lig_chains,
        native_rec_chains,
        native_lig_chains,
        cutoff=5.0,
        residue_id_map=residue_id_map,
        native_residue_id_map=native_residue_id_map,
    )
    logger.info("Fnat: %.3f Fnonnat: %.3f", fnat_list[0], fnonnat_list[0])

    clashes_list = _count_clashes(interface_coor, rec_chains, lig_chains)

    irmsd_list = interface_rmsd(
        clean_coor,
        clean_native_coor,
        native_rec_chains,
        native_lig_chains,
        cutoff=10.0,
        back_atom=back_atom,
        # Without any aligned interface atom, let interface_rmsd fall back to
        # its own selection-based path.
        index_pair=(iface_model_idx, iface_native_idx) if iface_model_idx else None,
    )
    logger.info("Interface RMSD: %s A", irmsd_list[0])
    logger.info("Clashes: %d", clashes_list[0])

    d1 = 8.5  # LRMS scaling distance (Angstrom)
    d2 = 1.5  # iRMS scaling distance (Angstrom)
    dockq_list = [
        (fnat + scale_rms(lrmsd, d1) + scale_rms(irmsd, d2)) / 3
        for fnat, lrmsd, irmsd in zip(fnat_list, lrmsd_list, irmsd_list)
    ]
    logger.info("DockQ Score pdb_cpp: %.3f", dockq_list[0])

    return {
        "Fnat": fnat_list,
        "Fnonnat": fnonnat_list,
        "rRMS": rmsd_prot_list,
        "iRMS": irmsd_list,
        "LRMS": lrmsd_list,
        "DockQ": dockq_list,
        "clashes": clashes_list,
    }
[docs]
# Upper bound on the number of candidate chain mappings scored during the
# automatic chain-map search.
_MAX_CHAIN_MAP_PERMUTATIONS = 720


def _ungapped_length(seq_by_chain, chain):
    """Return the gap-free sequence length of *chain* (0 if the chain is absent)."""
    return len(seq_by_chain.get(chain, "").replace("-", ""))


def _orient_interface(native_seq, n1, n2, m1=None, m2=None):
    """Return ``(rec_n, lig_n, rec_m, lig_m)`` with the longer native chain as receptor.

    On equal lengths the given order (*n1* first) is kept.
    """
    if _ungapped_length(native_seq, n1) >= _ungapped_length(native_seq, n2):
        return n1, n2, m1, m2
    return n2, n1, m2, m1


def _candidate_clusters(query_chains, query_seq, query_types,
                        target_chains, target_seq, target_types,
                        min_identity=0.9):
    """Map each query chain to the target chains it may be paired with.

    Candidates are the target chains of the same type whose sequence identity
    with the query is at least *min_identity*.  When none qualifies, the
    single most similar chain (same type if any exists, otherwise any
    target chain) is used instead.
    """
    target_clean = {
        ch: target_seq[ch].replace("-", "").upper() for ch in target_chains
    }
    clusters = {}
    for query in query_chains:
        query_s = query_seq[query].replace("-", "").upper()
        query_type = query_types.get(query, "protein")
        same_type = [t for t in target_chains if target_types.get(t) == query_type]
        candidates = [
            t for t in same_type
            if _sequence_identity(query_s, target_clean[t]) >= min_identity
        ]
        if not candidates:
            pool = same_type if same_type else target_chains
            candidates = [
                max(pool, key=lambda t: _sequence_identity(query_s, target_clean[t]))
            ]
        clusters[query] = candidates
    return clusters


def _iter_injective_maps(keys, clusters, used=frozenset()):
    """Lazily yield ``{key: candidate}`` dicts using each candidate at most once."""
    if not keys:
        yield {}
        return
    head, tail = keys[0], keys[1:]
    for candidate in clusters[head]:
        if candidate in used:
            continue
        for rest in _iter_injective_maps(tail, clusters, used | {candidate}):
            yield {head: candidate, **rest}


def _precompute_pairwise_scores(coor, native_coor, native_pairs, native_seq,
                                native_types, model_cands_for_nat,
                                model_backbone_ok, back_atom):
    """Score every plausible model chain pair against each native interface.

    Returns ``{(rec_m, lig_m, rec_n, lig_n): score}`` where *score* is the
    first-frame search-mode DockQ, or ``0.0`` when ``dockQ`` raised.
    """
    scores = {}
    for n1, n2 in native_pairs:
        if native_types.get(n1) != "protein" or native_types.get(n2) != "protein":
            continue
        rec_n, lig_n, _, _ = _orient_interface(native_seq, n1, n2)
        for rec_m in model_cands_for_nat.get(rec_n, []):
            if not model_backbone_ok.get(rec_m, True):
                continue
            for lig_m in model_cands_for_nat.get(lig_n, []):
                if rec_m == lig_m or not model_backbone_ok.get(lig_m, True):
                    continue
                key = (rec_m, lig_m, rec_n, lig_n)
                if key in scores:
                    continue
                try:
                    result = dockQ(
                        coor, native_coor,
                        rec_chains=[rec_m], lig_chains=[lig_m],
                        native_rec_chains=[rec_n], native_lig_chains=[lig_n],
                        back_atom=back_atom, _search_mode=True,
                    )
                    scores[key] = result["DockQ"][0]
                except Exception as exc:
                    # Best effort: an unscoreable pair just contributes nothing.
                    logger.debug(
                        "Pairwise precomp (%s,%s)→(%s,%s) failed: %s",
                        rec_m, lig_m, rec_n, lig_n, exc,
                    )
                    scores[key] = 0.0
    return scores


def _score_chain_map(candidate_map, native_pairs, native_seq, pairwise_scores):
    """Sum the precomputed pairwise scores selected by a native→model map."""
    total = 0.0
    for n1, n2 in native_pairs:
        m1 = candidate_map.get(n1)
        m2 = candidate_map.get(n2)
        if m1 is None or m2 is None or m1 == m2:
            continue
        rec_n, lig_n, rec_m, lig_m = _orient_interface(native_seq, n1, n2, m1, m2)
        total += pairwise_scores.get((rec_m, lig_m, rec_n, lig_n), 0.0)
    return total


def _best_chain_map(candidate_maps, score, initial_map=None,
                    max_maps=_MAX_CHAIN_MAP_PERMUTATIONS):
    """Return ``(best_map, best_score)`` over *initial_map* and *candidate_maps*.

    *candidate_maps* may be a lazy iterable; at most *max_maps* of its
    entries (ignoring those equal to *initial_map*) are scored, so the full
    set of permutations is never held in memory.  A later map replaces the
    current best only when its score is strictly higher.
    """
    best_total = -1.0
    best_map = None
    if initial_map is not None:
        best_total = score(initial_map)
        best_map = initial_map if best_total >= 0 else None
        logger.info("Identity mapping %s scores %.3f", initial_map, best_total)

    remaining = (m for m in candidate_maps if m != initial_map)
    # Pull one extra map only to learn whether the cap truncates the search.
    maps_list = list(itertools.islice(remaining, max_maps + 1))
    if len(maps_list) > max_maps:
        logger.warning(
            "Stopping chain-map search after %d permutations (more candidates remain); "
            "provide an explicit chain_map for an exhaustive search.",
            max_maps,
        )
        del maps_list[max_maps:]

    for cmap in maps_list:
        total = score(cmap)
        if total > best_total:
            best_total = total
            best_map = cmap
    return best_map, best_total


def _search_chain_map(coor, native_coor, native_seq, model_seq,
                      native_types, model_types, back_atom):
    """Find the native→model chain mapping with the best summed pairwise score.

    Returns ``(chain_map, native_iface_cache, model_backbone_cache)``; the
    two caches are the per-pair / per-chain viability flags computed here.
    """
    native_chains = sorted(native_seq)
    model_chains = sorted(model_seq)

    # Which native chain pairs are within 10 A of each other?
    native_iface = {}
    for n1, n2 in itertools.combinations(native_chains, 2):
        sel = native_coor.select_atoms(f"chain {n1} and within 10.0 of chain {n2}")
        native_iface[(n1, n2)] = len(np.unique(sel.models[0].uniq_resid)) > 0

    # Which model chains have at least one protein backbone atom?
    back_list = " ".join(back_atom)
    model_backbone = {}
    for mc in model_chains:
        sel = coor.select_atoms(f"protein and chain {mc} and name {back_list}")
        model_backbone[mc] = len(np.unique(sel.models[0].uniq_resid)) > 0

    # With fewer model than native chains, enumerate from the model side so
    # that every model chain is assigned and native chains may stay unmapped.
    inverse = len(model_chains) < len(native_chains)
    if inverse:
        clusters = _candidate_clusters(
            model_chains, model_seq, model_types,
            native_chains, native_seq, native_types,
        )
        model_cands_for_nat = {}
        for mc, nats in clusters.items():
            for nc in nats:
                model_cands_for_nat.setdefault(nc, []).append(mc)
        identity_map = None

        def candidate_maps():
            # The generator yields model→native dicts; flip them to native→model.
            for m2n in _iter_injective_maps(model_chains, clusters):
                yield {nat: mod for mod, nat in m2n.items()}
    else:
        clusters = _candidate_clusters(
            native_chains, native_seq, native_types,
            model_chains, model_seq, model_types,
        )
        model_cands_for_nat = clusters
        # Prefer the same chain name when it is a candidate, else the first one.
        identity_map = {
            nat: (nat if nat in clusters[nat] else clusters[nat][0])
            for nat in native_chains
        }

        def candidate_maps():
            return _iter_injective_maps(native_chains, clusters)

    n_combos = _count_chain_combinations(clusters)
    if inverse:
        logger.info(
            "Chain-map search (inverse): %d permutation(s) to evaluate", n_combos
        )
    else:
        logger.info("Chain-map search: %d permutation(s) to evaluate", n_combos)

    if n_combos == 1:
        if inverse:
            best_map = next(candidate_maps())
            logger.info("Single valid chain mapping (inverse): %s", best_map)
        else:
            best_map = identity_map
            logger.info("Single valid chain mapping: %s", best_map)
    else:
        native_pairs = [
            pair for pair in itertools.combinations(native_chains, 2)
            if native_iface.get(pair, False)
        ]
        pairwise_scores = _precompute_pairwise_scores(
            coor, native_coor, native_pairs, native_seq, native_types,
            model_cands_for_nat, model_backbone, back_atom,
        )
        logger.info(
            "Precomputed %d pairwise scores for %d native interface pairs",
            len(pairwise_scores), len(native_pairs),
        )
        best_map, _ = _best_chain_map(
            candidate_maps(),
            lambda cmap: _score_chain_map(
                cmap, native_pairs, native_seq, pairwise_scores
            ),
            initial_map=identity_map,
        )

    if best_map is None:
        best_map = {ch: ch for ch in native_seq if ch in model_seq}
        logger.warning("No valid chain mapping found; falling back to identity mapping.")
    logger.info("Optimal chain mapping: %s", best_map)
    return best_map, native_iface, model_backbone


def dockQ_multimer(
    coor,
    native_coor,
    chain_map=None,
    back_atom=None,
    n_cpu=1,
    _search_mode=False,
    _native_iface_cache=None,
    _model_backbone_cache=None,
):
    r"""Compute DockQ over all pairwise native chain interfaces (multimer).

    Scores every :math:`\binom{n}{2}` interface between the *n* mapped native
    chains and returns per-interface DockQ metrics as well as *GlobalDockQ*
    (the average DockQ over all scored interfaces), mirroring the DockQ v2
    multimer output.

    Parameters
    ----------
    coor : Coor
        Model coordinates.
    native_coor : Coor
        Native coordinates.
    chain_map : dict[str, str], optional
        Mapping from **native** chain IDs to **model** chain IDs.  If
        ``None``, the mapping is searched automatically: chains are grouped
        by sequence identity (>= 0.9, falling back to the most similar
        chain), candidate mappings are ranked by the sum of fast
        search-mode DockQ scores over the native interfaces, and the best
        one is used.  At most 720 candidate mappings are scored.
    back_atom : list[str], optional
        Backbone atom names used for alignment and RMSD calculations
        (default: ``["CA", "N", "C", "O"]``).
    n_cpu : int, optional
        Currently not used by this function.
    _search_mode : bool, optional
        Private; forwarded to :func:`dockQ` for the per-interface scoring.
    _native_iface_cache : dict, optional
        Private; ``{(native_ch1, native_ch2): bool}`` flags telling whether
        a native pair has an interface.  Only consulted when *chain_map* is
        given; the automatic search recomputes it.
    _model_backbone_cache : dict, optional
        Private; ``{model_chain: bool}`` flags telling whether a model chain
        has protein backbone atoms.  Only consulted when *chain_map* is
        given; the automatic search recomputes it.

    Returns
    -------
    dict
        Dictionary with three keys:

        ``"interfaces"``
            ``dict[(native_ch1, native_ch2), result]`` where *result* is the
            dict returned by :func:`dockQ` for that pair, extended with
            ``"model_rec_chain"`` and ``"model_lig_chain"`` (or ``None``
            when the interface could not be scored).
        ``"GlobalDockQ"``
            ``list[float]`` — average DockQ over all scored interfaces, one
            value per model frame; empty when no interface was scored.
        ``"chain_map"``
            The native→model chain mapping that was used.

    Notes
    -----
    The larger native chain of each pair is used as receptor to match the
    DockQ v2 convention.  For chains of equal length the chain that sorts
    first by chain ID is the receptor.
    """
    if back_atom is None:
        back_atom = ["CA", "N", "C", "O"]

    native_seq = native_coor.get_aa_seq()
    model_seq = coor.get_aa_seq()
    # Chains are taken from get_aa_seq() only, so all of them are labelled
    # protein; the type checks below are placeholders for other polymers.
    native_chain_types = {ch: "protein" for ch in native_seq}
    model_chain_types = {ch: "protein" for ch in model_seq}

    native_iface_cache = _native_iface_cache
    model_backbone_cache = _model_backbone_cache
    if chain_map is None:
        chain_map, native_iface_cache, model_backbone_cache = _search_chain_map(
            coor, native_coor, native_seq, model_seq,
            native_chain_types, model_chain_types, back_atom,
        )

    interfaces = {}
    for n1, n2 in itertools.combinations(sorted(chain_map.keys()), 2):
        m1 = chain_map[n1]
        m2 = chain_map[n2]
        if m1 == m2:
            logger.info(
                "Skipping interface (%s, %s): both map to the same model chain %s",
                n1, n2, m1,
            )
            interfaces[(n1, n2)] = None
            continue
        if (native_chain_types.get(n1) != "protein"
                or native_chain_types.get(n2) != "protein"):
            logger.info(
                "Skipping interface (%s:%s, %s:%s): non-protein interfaces not yet supported",
                n1, native_chain_types.get(n1, "?"),
                n2, native_chain_types.get(n2, "?"),
            )
            interfaces[(n1, n2)] = None
            continue

        rec_n, lig_n, rec_m, lig_m = _orient_interface(native_seq, n1, n2, m1, m2)
        logger.info(
            "Scoring interface native(%s, %s) -> model(%s, %s)",
            rec_n, lig_n, rec_m, lig_m,
        )

        # The cache may hold the pair under either key order.
        cached_native = (
            native_iface_cache.get((n1, n2), native_iface_cache.get((n2, n1)))
            if native_iface_cache is not None else None
        )
        viable, reason = _check_interface_viable(
            coor, native_coor, rec_m, lig_m, rec_n, lig_n, back_atom,
            _native_has_interface=cached_native,
            _model_backbone_cache=model_backbone_cache,
        )
        if not viable:
            logger.info("Skipping interface (%s, %s): %s", n1, n2, reason)
            interfaces[(n1, n2)] = None
            continue

        try:
            result = dockQ(
                coor,
                native_coor,
                rec_chains=[rec_m],
                lig_chains=[lig_m],
                native_rec_chains=[rec_n],
                native_lig_chains=[lig_n],
                back_atom=back_atom,
                _search_mode=_search_mode,
            )
            interfaces[(n1, n2)] = {
                **result,
                "model_rec_chain": rec_m,
                "model_lig_chain": lig_m,
            }
        except Exception as exc:
            # One failing interface must not abort the whole multimer score.
            logger.info(
                "Could not compute DockQ for interface (%s, %s): %s", n1, n2, exc
            )
            interfaces[(n1, n2)] = None

    valid = [v for v in interfaces.values() if v is not None]
    if valid:
        n_frames = len(valid[0]["DockQ"])
        global_dockq = [
            sum(v["DockQ"][i] for v in valid) / len(valid)
            for i in range(n_frames)
        ]
    else:
        logger.warning("No valid interfaces found; GlobalDockQ is undefined.")
        global_dockq = []
    logger.info(
        "GlobalDockQ (avg over %d interfaces): %.3f",
        len(valid), global_dockq[0] if global_dockq else float("nan"),
    )

    return {
        "interfaces": interfaces,
        "GlobalDockQ": global_dockq,
        "chain_map": chain_map,
    }